{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install third-party libraries into a persistent directory (run once, then keep commented out):\n",
    "# !mkdir /home/aistudio/external-libraries\n",
    "# !pip install imgaug -t /home/aistudio/external-libraries\n",
    "import sys\n",
    "\n",
    "# Make the persistent library directory importable. The membership check keeps\n",
    "# sys.path from gaining a duplicate entry every time this cell is re-run.\n",
    "EXTERNAL_LIB_DIR = '/home/aistudio/external-libraries'\n",
    "if EXTERNAL_LIB_DIR not in sys.path:\n",
    "    sys.path.append(EXTERNAL_LIB_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shape: (32, 32, 3)\n",
      "label value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGylJREFUeJztnWlsXOd1ht9zZ+Um7pQoURIteZUVW04c17HjVFmcOGkAJ0VhJGgDA3WWAgnaoPljuECbAv2RAk2CoghSJKhrB0jjpHFSu47T2HGdOnYTWbIta7MsibQW7hSX4XAWznK//phhyuF7eDniSBQpnwcQxDm8c+937/DMved857yfOOdgGIaOd7kHYBhrGXMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAIwBzGMAGpyEBG5R0TeFJFTIvLgxRqUYawVZKUz6SISAnACwN0ABgDsB/Bp59yxpd7T0dHhent7V3Q8Yzn4c8zPzZEtlU6TrbFpg7rHcDhc+7BWgK/YisWCuu3cXJZsoTB/7+dylduNjYwjMZ2U5cZSyxW4DcAp51w/AIjIYwDuBbCkg/T29uLAgQM1HNJYkiI7w8jZPrLte/lVst31oXvUXba1d9Q+rmUoKrZ0ka3J2Un1/f19b5Cttb2BbGfPnqx4/eefe6iq8dXyiLUFwLkFrwfKtgpE5PMickBEDoyPj9dwOMNYfS55kO6c+45z7lbn3K2dnZ2X+nCGcVGp5RFrEMDWBa97yrYLwqqJLxxfeR6X/BTZkmP9ZHv+yZ/wdkl+jgeAP/nsZ9mofF6+r3yGylevAz/y55X3Dg2fJdvk9IA6xuFzR8nWf/I82RIzlddnLptS97eYWu4g+wFcIyJXiUgUwKcAPFnD/gxjzbHiO4hzriAiXwLwCwAhAA8759idDWMdU1Mezzn3NICnL9JYDGPNYTPphhHA5ZkJWgaRZedv3jZoKQxPlNmDYpLfm+G0eoOfI9vE8Ih67NGRUbKFhL9Tm1uayRaJRsjmK0G6czwtGOa3Il/MqGNs39hOttFxDtKH+4Yq95fPq/tbjN1BDCMAcxDDCMAcxDACMAcxjADWZJC+WmhVo87nor/CFAd9mcQsvzfKRXIbtmzWD64Eu6IErJ7Ps+Yzw+fIdvrIb8n21hvHeX9eVNkfz1wDwK+efpxsrZu3ku2OO+/iN4e5QnhiOkG2uVlOEGSzY2RzBU5CAMDYJFcLTE3z5+X8xde7ukSQ3UEMIwBzEMMIwBzEMAIwBzGMAN7WQTp8npE+f4oD27FXXiRbepIDzpEcf99ce9de9dDX3Hwr2bwIfxyHjx4m22vPP0+2pBK4z4zxTHgkHCNbdmKIbADw/M/OkO2G3/8I2d7zvg/yPud4xn5qjPfXv59L+UaHuBOyffs2dYxpn8vW82m+jlGvq+K1VPmnb3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIoKYsloicBpBESd6o4Jzj1MwaxmW5rGTiTc6gYHqGTG0hRcjM48xN/wvPqscOOy51iG/mTM33fvyfZDt64CDZdrRymUubx2NsUDJlxZDSgAGg/wRnt1488WOydffcSLa7bruBbOPH/5dsrz/zU7LNTbMARWpwlzrG+l3vYlsd63k1XdVa8Toaq04+4WKked/vnOPiF8O4ArBHLMMIoFYHcQCeEZFXROTz2gamrGisZ2p1kPc6594J4KMAvigi71u8gSkrGuuZWmV/Bsv/j4nIT1EStH7hgnZyGfUZvCj3RjR2cf/G+MBbZMuOs9JfQ5T7OWay+gke/61SvtK6nWzPPPMSb5fk3ogmr5ttrXGypeY4cD9+VhdtGEmxZMTABAfQ33/kX3m7g11kS59j4fKGIpeKxOq4HGYuxar0ALC9kQNyb+PVZMtK5Wcd0pQhFFZ8BxGRBhFpmv8ZwIcBHFnp/gxjLVLLHWQjgJ+WJXrCAP7NOfdfF2VU
hrFGqEV6tB/AzRdxLIax5rA0r2EEcPn7QTTpwGoD96VWTqjy/U5ZYmzTO/immJ+dJlvf2TfJlp7kNHYuVqce+8QJXhkp1cjqgeE8n+TMBK+2lFBWVYpv58B9ZoqD7ENn9CB9PMdJjKZmVlE8e+p1su2b5CUVrungwDga4fObnmNbU5d+HYeHuA9mQ30bH6dtkQKjVLfsht1BDCMAcxDDCMAcxDACMAcxjAAue5CuxUpKJfgS772A9Q2VJRVEWR8vEuPZ5y233cn7UyZih1/lWe8eRYkQACbOs2DEoX2vka0uzIF7RxMHz3vv4jH+3s1cIv5P3/oW2ZIZLtMH9GuhKRymlVnu2FZelsB3HLiPjnErQbh1I9mkQS9Tev0otyckXmHhje4dOypep2b4uBp2BzGMAMxBDCMAcxDDCMAcxDACWPUgffGi85qH+krwnc1x/3hUmQkH9HX0PG16XQncC8r0fN8kdxRPKQHs3LW7yXbju+5Qx5g/y7PhP/rZL3m7DJeDf/KevWT7w49/mGwnT/HSAGMpTg7kXEgdY8TxttEwb9sU52vR0MJBdSLP59KwkWf7XR0vnTAwri9/UMxwEiOnaAg8/2RloXlymqsjNOwOYhgBmIMYRgDmIIYRgDmIYQSwbJAuIg8D+DiAMefc7rKtDcAPAfQCOA3gPucc11EvwncOc/nKWdu40hc+k+b1/17av49sGxob1ePccuNNZGuqqydbscj92YPjLJb2qxc5eH7rLK/rN6fMSMc296pjLCR5VnnsDC8PMJvka7Gzl2fnw+CAejrBwWrO5yC7UNRWawT8NAfGnuMSglCcP8OJSf5zGB3jZEedsq5jQzMnZBpbeDsAaFKSBnVhTrRs7WipeN13Tl/yYTHV3EEeAXDPItuDAJ5zzl0D4Lnya8O44ljWQZxzLwBYnJO8F8Cj5Z8fBfCJizwuw1gTrDQG2eicGy7/PIKSgIPKQuG48yYcZ6wzag7SnXMOSze/VgjHdZhwnLHOWOlM+qiIdDvnhkWkGwCv/K4gAsiioGpmloPQ/QdfJdvZ4UGyxaIsMAYAnW0sJnZd706yJWYmyHbwIAu6DZ8+RraRsxxwjk3xuRw8zIrmAHBbz/Vk27GJv0Cm2ri/urmDZ5/PDXFf+fAwB6KpJAfPLY16v3dqloP0mSmuANjR1UO2xjj/aaXrFGX5AidKiikeY9HTy9NzrVxWjzAnLJqbK88xHKru3rDSO8iTAO4v/3w/gCdWuB/DWNMs6yAi8gMAvwFwnYgMiMgDAL4G4G4ROQngQ+XXhnHFsewjlnPu00v8itf+NYwrDJtJN4wAVrXc3flAca4ygHpp38u03StHD5Ft5/UcCA6dS6jH+Y+nniPbxz+WJ1vfaRZv6zvHSu5eiMu5J5VZ4cGB02SLF9+tjvEdvb1k+7M//QzZtNnwnS0s3jY0xEmMk4c5uZCc4FR7c7sS6AIoFpQydmXSfUtrE9mcshyd+PzmkMcJ0FBIaUPI8+cHAGlF1C8U5pn9ol+ZDHDQqwcWY3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIYFWzWEW/iORsZebpv1/gXov2zVwqMpfl/okz/bpsvyiZkZcPserhESVbJsolCWmXKcw9C3s/uIdsXa1cKgIAhTRneXZfdx3ZPGW5goFfcJau7jxnc+5u4nUCN13LvTIHxofJBgDH67j3o7eHy1w6lbKSbJbLVLS+E9/n7JS2fmAsrJfD5JSelajS++NF9LKk5bA7iGEEYA5iGAGYgxhGAOYghhHAqgbp4gkiDZXBUnMbCy8MDrKk/aHXeQn2M6e4/wIAuns4oGvfxCUbvs+9CFOTvM+IEvT37lAC4M1ccpGZ00skclkO0ouK6EPmNJeQpE9zUJ1IcDBfp5SkvHsbl+x0x3jcALBhgvtJwq0snuBH+Dq6IgfaogTkxTwnX0SLpxWxidI+ufejMMf7jHqL329rFBpGzZiDGEYA5iCGEYA5iGEEsFJlxa8C
+ByA+eaCh5xzTy+3r1Q6i32vVfZgFBXp/VCIh/VWP/dpDA7qQXpjK4sfFIutZEsmeW09LUi/Sglsuzo5SB8YOEG21rAusx+5kRMJ4QRL+Z87eJRsR2d4GYGfHePtEj4Hqy1xnmX+8HW3qmO8I8oKjudGT5Mt1MwBeaGeezrySvDsfE5MOJ8/fy3wBoBiUZmJd8qM/eKlMqpc33KlyooA8E3n3J7yv2WdwzDWIytVVjSMtwW1xCBfEpFDIvKwiPDzS5mFyoqJKlf1MYy1wkod5NsAdgLYA2AYwNeX2nChsmJzS8tSmxnGmmRFM+nOudH5n0XkuwCequZ9c7kM3jp9uHIAilR9VzuXu4vSZB+v02dXP/SBj5Dt+l07yFacYwXHrjZFOr97G9k623j2ecdWLlff1rlZHaMm7JcY4uUPJmZYtLIfHJg23cRl7IUMVw9MT7LQxRNnWNwBAG7s4tL2q7Rp7hFOLmSaeYbbFbhFoFDgIN3Pc9BfXGLmO53lpEq8QVlbsW7xuC/hTHpZbnSeTwLgOhDDuAKoJs37AwB7AXSIyACAvwGwV0T2oOSGpwF84RKO0TAuGytVVvyXSzAWw1hz2Ey6YQSwquXu0aiPzb2VAV1rB8/s5vMcuH3kD1ihcGKCg0MACMc5SMvleJ+33HIj2bIpDiSHlKUO9tzA793Zu51s0+d12f7hES4lnzw3QDbvat7nXe/fS7asx4HtzCxfnwJfGhx98zAbAZx98xTZukIc3G7wOIHifN7OE95OlJYDpwyysERMnVMUF8NFRZmxUHktnDLbrmF3EMMIwBzEMAIwBzGMAMxBDCOAVQ3Sk6kEXtj/8wpbQQnItvVyufqeO3aR7UyfLhznCQe7k7O8HqFf5Jn4ZIKDxokZDrRffp1npI/38ez64KAepMeV8u3rY7wMgdfAM/EjSln8S/t/TbaCEodGYlxmn5jVVx/ORfj6JOKcDAiHeLs0+PyKSv94aHEZOoCwYssraxkCgCf8HR8K83iyc5XJF19JIqj7r2orw3ibYg5iGAGYgxhGAOYghhHAqgbpsXgYO6+uDETzSrlz1yZtVphLwZMpvdExHOaS7HyR19tLJDmAzitTtm09nDSIxDhID8W5V3z79fp3kF9ke1OYg/xfv8jrKB49yWJyTU3cayOeorqe40qBiWn9OvqO3+8UtfqkokCfyXG/vwjPcEejvJ6gZsso6v4AEI7y34rn8bUtUILAgnTDqBlzEMMIwBzEMAIwBzGMAKrpKNwK4HsANqIU2XzHOfePItIG4IcAelHqKrzPOcfR2gIa6uK4dU9l3/asUpJ97NjrZJuc5l1fv2u3epymxg3amZBlbJwDtXyOt0tO8zJfMymefW5v26TYdMGX2Sx/N8VDHGiH6zlwL+b5mkWFVfLrG1mJ3VMSAdPj59QxtnT3kq01yn8yiUkWzPOFky+xGAffnhK4Fwpcwq61QABAg7LcWlEpIWhorFS69zxddJDGV8U2BQBfcc7tAnA7gC+KyC4ADwJ4zjl3DYDnyq8N44qiGuG4Yefcq+WfkwDeALAFwL0AHi1v9iiAT1yqQRrG5eKCYhAR6QVwC4B9ADY65+ZXchlB6RFMe8/vhOOmJ3mewDDWMlU7iIg0AngcwJedcxUzbM45hyVmXhYKx7W08TOxYaxlqnIQEYmg5Bzfd879pGwendfHKv/PCmeGsc6pJoslKMn8vOGc+8aCXz0J4H4AXyv//8Ry+yr6BSRmKwUQPHBZyEyCsxDHj3PW6FT//6jH6dnGyow37dlJtm3KdnUeZ8CcIgJQVPpYohHutRCuhAAA1Gf4httdz2O8ZQ9naTqaudzjpRdeIltiirWQtf6b8UH9u801cH9K8VoeI5TrowlnxMJ8MTIpLknxi9z7EY3r3+UhRXEzl1GUKRZXGlVXaVJVLdadAD4D4LCIHCzbHkLJMX4kIg8AOAPgvuoOaRjrh2qE416ENolQ4oMXdziGsbawmXTDCMAcxDAC
WNV+EE+A+milTzqfg6w7b38X2XbuvIFs/WdOq8cZG2fRhukJRSY/wgmC0QwnA1paOHBvauKSDRdRylRmuG8EANoaeN3Dzi7uO0lu5cB//29+Q7aJaVZ/9JVrqyHcKgMAaGvjX7Rt4XKYlPI1G1HEFKLachXC0XImw6U0ztOj6oKizKiddnrRPqu9NnYHMYwAzEEMIwBzEMMIwBzEMAJY1SAd4uCFKoMqL6LI6SsL03ds2kK2G3br6/9lsxzk+Yqq3/D5YbKNJTjYHZsZJdumbg6om5s5qPWX6DuYzfN300T2ZbINTrKwxJFjPGs+l+Vxx+NLRN+LaGjWA+CtbUrvR/Is2bwWPk5LhKsUfHBPhyqw4Pizmk3q1zHkKYG/sgAkTfYvNbO3CLuDGEYA5iCGEYA5iGEEYA5iGAGsapCezc3hxFDlunfNLTwjHctxYLohzs1WrcpsNgDEldJoDywY0NXK5dyRMM9czyR5dj3kOMqbmeby8tFxXnYBABKjrBR5qoPFKnqabyHbH9/3PrId3s/v1dZlbGllEYk5pUwfANw0VwEcOXaIbL2dLBjR3sAl+QVFCXNCKW3fEOHZeqeIOwDAbIIFNeL1/LdSv6FyjJ6nVzgsxu4ghhGAOYhhBGAOYhgBmIMYRgC1KCt+FcDnAMxHsA85554O2lfRL2J6tjIAzxZY1j6mLC2Qb2omW3J2KXU8LmWur+PArbG+m2zxKAecnc1c7p5X1A215RQGTg2pIwwrSxMcGmWFw3PKZPi1US79b1Ouz+YurjTwlPLwbL0eAE9EuFd9CzgxUhfmY9c1KIqQaT6ZfJFVFHNZXqIhn9PXKEwrypyxGB+7tbVS9TIUrk5jpJos1ryy4qsi0gTgFRF5tvy7bzrn/qGqIxnGOqSanvRhAMPln5MiMq+saBhXPLUoKwLAl0TkkIg8LCKqSvNCZcVUgm+nhrGWqUVZ8dsAdgLYg9Id5uva+xYqKzYoVbqGsZapaiZdU1Z0zo0u+P13ATy13H6ikTh6Nl5dYSsoUvWeUq6cyfCs8Ni0rvWrzXxv3c5LE6QVOf5skvfZ2KjMFLcrs/ARFnnbsV1f/6++kQPW/j4u3Y6FlSUMuvmatWzkRMLsLM8yh4ocAO+88WqyAYB/nMvO8wUedzymLEHg8RjbG3m7cITPeeo8Vx+Iz/oBAJDO8FNJOMbbeqHKP3VtvUSNZe8gSykrzsuOlvkkgCNVHdEw1hG1KCt+WkT2oJT6PQ3gC5dkhIZxGalFWTFwzsMwrgRsJt0wAljVcnfnisgVKoPgWIxLrRvquNy5WOCZ1HSClcEBoKGeA79ingPyyTSvexhX1uDTFNp9jwPYdI5n9rs2aeslAvX1HLBu2qSUiBf5OHM+zx63t3EPeCbB28UjnHAI1fN2ABAf54C8boTPx/M58C+Ckx1eiD/rugb+rNMpTshE4rrQW9FxQsYXDtwzhcoqB1/pe9ewO4hhBGAOYhgBmIMYRgDmIIYRwKoG6UW/iFS6cma54LNoWXKWhdpCwkGtCAe1ANDcxPZ0mvcZUZYEkzAH+KksB9/JIS5t12auoZwfADifM+chRR3e95VgV8m6F9PcIhAOcWCbSnNAnczpffPSzLP40sABfeo8B9V5JQgugI89l+HrmHccZA8MD6pjHBnjSoXOzZwMcOnKJE9RKfvXsDuIYQRgDmIYAZiDGEYA5iCGEYA5iGEEsLqlJr6HfKayVCE1y83z2kLyuRxnaaJKuQcATL3FJSgzKc6C7H7HtWRLjHBGxxO+TOoad0pm6q0+PfsSi3JWrqWNsy/Nrfwd1tzCZTPIcbYrrpSzJGZZJCOd5iwUALiMIvAQ4cxfHlx+4ucVgYYQfy75MGex0nnOTPWfZUELAEgm+G+gpYf7QQpe5Tk66NnFxdgdxDACMAcxjADMQQwjgGpabuMi8rKIvC4iR0Xkb8v2q0Rkn4icEpEfiojyYGwY65tqgvQ5AB9wzs2WxRteFJGfA/hLlITj
HhORfwbwAEpKJ0uSz/kYGqgsx/CVwDYa4RKHwWEOnnM5XRAhrCxh0NLKgeTgsFLS4vF4PPD+6pW+Ck2VMRzTpY6OnzpOts1ZHmP4PJdnRCKcIGisZzXBhgZWPMxkOEgPRZfqteAAujHew9t5SsNMhktSpgp8vaWLy3MmZ/mzTs7qY8w6/o7vfScrT+6+ZXvF64OHn1H3t5hl7yCuxHwxUqT8zwH4AIAfl+2PAvhEVUc0jHVEVTGIiITKgg1jAJ4F0Adg2jk3nwccwBJqiwuF49KzejrRMNYqVTmIc67onNsDoAfAbQCur/YAC4Xj6hstTDHWFxeUxXLOTQN4HsB7ALSI/G4GrQeAPiNmGOuYapY/6ASQd85Ni0gdgLsB/D1KjvJHAB4DcD+AJ5bb19xcHn19w5X7V5YqaGpk28wU+3IyqT+y7drNsv+921kJcWDoNB+7iSWGXZ5nXesbOKCOKYF77zZdwa+tjWeas1meaZ5W1glMTClqlG3Kun557m3xPD5uInVeHWOuyLPz0wkWSdiQ4hn7mBI8Zz3eXyzK2yWSSh9LSv8ub97CTyXxTkW0o7EyOeGUXhmNarJY3QAeFZEQSnecHznnnhKRYwAeE5G/A/AaSuqLhnFFUY1w3CGUFN0X2/tRikcM44rFZtINIwBzEMMIQJyrruz3ohxMZBzAGQAdAPTIcP1h57I2We5ctjvnOpfbyao6yO8OKnLAOXfrqh/4EmDnsja5WOdij1iGEYA5iGEEcLkc5DuX6biXAjuXtclFOZfLEoMYxnrBHrEMIwBzEMMIYNUdRETuEZE3y626D6728WtBRB4WkTERObLA1iYiz4rIyfL/XO24BhGRrSLyvIgcK7dS/0XZvu7O51K2ha+qg5QLHr8F4KMAdqG0Uu6u1RxDjTwC4J5FtgcBPOecuwbAc+XX64ECgK8453YBuB3AF8ufxXo8n/m28JsB7AFwj4jcjlLV+Tedc1cDmEKpLfyCWO07yG0ATjnn+p1zOZRK5e9d5TGsGOfcCwAWN8Lfi1LLMbCOWo+dc8POuVfLPycBvIFSV+i6O59L2Ra+2g6yBcBCibwlW3XXERudc/NNLiMANl7OwawEEelFqWJ7H9bp+dTSFh6EBekXEVfKma+rvLmINAJ4HMCXnauUMVlP51NLW3gQq+0ggwC2Lnh9JbTqjopINwCU/2ex4TVKWcbpcQDfd879pGxet+cDXPy28NV2kP0ArilnF6IAPgXgyVUew8XmSZRajoEqW4/XAiIiKHWBvuGc+8aCX6278xGRThFpKf883xb+Bv6/LRxY6bk451b1H4CPATiB0jPiX6328Wsc+w8ADAPIo/RM+wCAdpSyPScB/BJA2+UeZ5Xn8l6UHp8OAThY/vex9Xg+AG5Cqe37EIAjAP66bN8B4GUApwD8O4DYhe7bSk0MIwAL0g0jAHMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAL4P/reBAlsXKWPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "import paddle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Read one batch of CIFAR-100 training samples\n",
    "reader = paddle.batch(\n",
    "    paddle.dataset.cifar.train100(),\n",
    "    batch_size=8) # batched dataset reader\n",
    "data = next(reader()) # first batch: list of (image, label) samples\n",
    "index = 0 # position of the sample to display within the batch\n",
    "\n",
    "# Read the image\n",
    "image = np.array([x[0] for x in data]).astype(np.float32) # pixel values as float32 in [0, 1]\n",
    "# Scale to [0, 255] and round: astype(np.uint8) truncates, so a float32 product such as\n",
    "# 2.9999998 would otherwise become 2 instead of 3.\n",
    "image = np.clip(np.round(image * 255), 0, 255)\n",
    "image = image[index].reshape((3, 32, 32)).transpose((1, 2, 0)).astype(np.uint8) # CHW -> HWC, uint8\n",
    "print('image shape:', image.shape)\n",
    "\n",
    "# Read the label\n",
    "label = np.array([x[1] for x in data]).astype(np.int64) # integer class ids\n",
    "vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "         'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "         'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "         'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "         'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "         'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "         'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "         'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "         'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "         'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "         'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "         'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "         'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "         'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "         'baby', 'boy', 'girl', 'man', 'woman',\n",
    "         'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "         'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "         'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "         'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "         'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # class names\n",
    "vlist.sort() # CIFAR-100 fine-label ids follow the alphabetical order of the class names\n",
    "print('label value:', vlist[label[index]])\n",
    "\n",
    "# Display the image\n",
    "image = Image.fromarray(image)       # numpy array -> PIL image\n",
    "out_dir = './work/out'\n",
    "os.makedirs(out_dir, exist_ok=True) # image.save raises if the directory does not exist\n",
    "image.save(os.path.join(out_dir, 'img.png'))\n",
    "plt.figure(figsize=(3, 3))\n",
    "plt.imshow(image)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n",
      "valid_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Per-channel normalization statistics (RGB, pixels scaled to [0, 1]).\n",
    "# NOTE(review): these look like the commonly quoted CIFAR-10 statistics; the CIFAR-100 training\n",
    "# set is usually quoted as mean (0.5071, 0.4865, 0.4409) / std (0.2673, 0.2564, 0.2762).\n",
    "# They are kept as defaults so existing models stay compatible; pass mean/stdv to override.\n",
    "CIFAR_MEAN = (0.4914, 0.4822, 0.4465)\n",
    "CIFAR_STDV = (0.2471, 0.2435, 0.2616)\n",
    "\n",
    "def _to_uint8_bhwc(images):\n",
    "    \"\"\"Convert a BCHW float batch in [0, 1] to a BHWC uint8 batch in [0, 255].\"\"\"\n",
    "    # Round before casting: astype(np.uint8) truncates, so 2.9999998 would become 2.\n",
    "    images = np.clip(np.round(images * 255), 0, 255)\n",
    "    return images.transpose((0, 2, 3, 1)).astype(np.uint8)\n",
    "\n",
    "def _normalize_to_bchw(images, mean, stdv):\n",
    "    \"\"\"Standardize a BHWC uint8 batch per channel and return it as BCHW float32.\"\"\"\n",
    "    mean = np.asarray(mean, dtype=np.float64).reshape((1, 1, 1, -1))\n",
    "    stdv = np.asarray(stdv, dtype=np.float64).reshape((1, 1, 1, -1))\n",
    "    images = (images / 255.0 - mean) / stdv\n",
    "    return images.transpose((0, 3, 1, 2)).astype(np.float32)\n",
    "\n",
    "# Training-time preprocessing\n",
    "def train_augment(images, mean=CIFAR_MEAN, stdv=CIFAR_STDV):\n",
    "    \"\"\"\n",
    "    Randomly augment and normalize a training batch.\n",
    "    Args:\n",
    "        images - float array in BCHW layout with values in [0, 1]\n",
    "        mean   - per-channel mean used for normalization\n",
    "        stdv   - per-channel standard deviation used for normalization\n",
    "    Returns:\n",
    "        float32 array in BCHW layout, normalized per channel\n",
    "    \"\"\"\n",
    "    images = _to_uint8_bhwc(images) # augmentation runs on uint8 BHWC pixels\n",
    "\n",
    "    sometimes = lambda aug: iaa.Sometimes(0.5, aug) # apply aug to roughly half of the images\n",
    "    seq = iaa.Sequential([\n",
    "        sometimes(iaa.CropAndPad(px=(-4, 4))), # random crop/pad by up to 4 pixels\n",
    "        iaa.Fliplr(0.5)])                      # horizontal flip with probability 0.5\n",
    "    images = seq(images=images)\n",
    "\n",
    "    return _normalize_to_bchw(images, mean, stdv)\n",
    "\n",
    "# Validation-time preprocessing (no random augmentation)\n",
    "def valid_augment(images, mean=CIFAR_MEAN, stdv=CIFAR_STDV):\n",
    "    \"\"\"\n",
    "    Normalize a validation batch with the training pipeline minus the random augmentation.\n",
    "    Args and return value are the same as for train_augment.\n",
    "    \"\"\"\n",
    "    images = _to_uint8_bhwc(images) # same uint8 quantization as the training path\n",
    "    return _normalize_to_bchw(images, mean, stdv)\n",
    "\n",
    "# Pull the first batch from a batched reader and turn it into network-ready arrays\n",
    "def _load_first_batch(batched_reader, augment):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "        batch  - the raw list of (image, label) samples\n",
    "        images - pixels reshaped to (B, 3, 32, 32) float32, then passed through augment\n",
    "        labels - int64 array of shape (B, 1)\n",
    "    \"\"\"\n",
    "    batch = next(batched_reader())\n",
    "    pixels = np.array([sample[0] for sample in batch]).reshape((-1, 3, 32, 32)).astype(np.float32)\n",
    "    images = augment(pixels)\n",
    "    labels = np.array([sample[1] for sample in batch]).reshape((-1, 1)).astype(np.int64)\n",
    "    return batch, images, labels\n",
    "\n",
    "# Training data: CIFAR-100 train split, shuffled with a 50000-sample buffer\n",
    "train_reader = paddle.batch(\n",
    "    paddle.reader.shuffle(paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "    batch_size=128)\n",
    "train_data, train_image, train_label = _load_first_batch(train_reader, train_augment)\n",
    "print('train_data: image shape {}, label shape:{}'.format(train_image.shape, train_label.shape))\n",
    "\n",
    "# Validation data: CIFAR-100 test split, not shuffled\n",
    "valid_reader = paddle.batch(\n",
    "    paddle.dataset.cifar.test100(),\n",
    "    batch_size=128)\n",
    "valid_data, valid_image, valid_label = _load_first_batch(valid_reader, valid_augment)\n",
    "print('valid_data: image shape {}, label shape:{}'.format(valid_image.shape, valid_label.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型设计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm\n",
    "import math\n",
    "\n",
    "# Backbone stages, one tuple per stage: (input channels, output channels, stride, basic structures, queue length)\n",
    "group_arch = [(3, 256, 1, 2, 3), (256, 512, 2, 2, 3), (512, 1024, 2, 2, 3)]\n",
    "group_dim  = group_arch[-1][1] # output channels of the last stage; feeds the classifier\n",
    "class_dim  = 100               # number of classes (CIFAR-100)\n",
    "\n",
    "# Convolution unit: convolution followed by batch normalization (which applies the activation)\n",
    "class ConvUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=3, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Build a bias-free convolution and a batch-norm layer.\n",
    "        Output size: H/W = (H/W + 2*P - F)/S + 1 with P = (F - 1)//2.\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - convolution kernel size\n",
    "            stride      - convolution stride\n",
    "            act         - activation applied by the batch-norm layer (None for a linear output)\n",
    "        \"\"\"\n",
    "        super(ConvUnit, self).__init__()\n",
    "\n",
    "        # For odd kernels this padding keeps H/W unchanged at stride 1\n",
    "        pad = (filter_size - 1) // 2\n",
    "\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=filter_size,\n",
    "            stride=stride,\n",
    "            padding=pad,\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False), # MSRA (He) normal initialization\n",
    "            bias_attr=False,                                  # no bias: batch norm adds its own shift\n",
    "            act=None)\n",
    "\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0), # scale starts at 1\n",
    "            bias_attr=fluid.initializer.Constant(0.0),  # shift starts at 0\n",
    "            act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Convolve x, then batch-normalize (and activate) the result.\"\"\"\n",
    "        return self.norm(self.conv(x))\n",
    "\n",
    "# Projection unit: average pooling -> 1x1 convolution -> batch normalization\n",
    "class ProjUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=1, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Build the projection used on a shortcut whose shape must change.\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - average-pooling window size (the convolution itself is always 1x1)\n",
    "            stride      - average-pooling stride\n",
    "            act         - activation applied by the batch-norm layer\n",
    "        \"\"\"\n",
    "        super(ProjUnit, self).__init__()\n",
    "\n",
    "        # Spatial downsampling is done by the pooling, not by the convolution\n",
    "        self.pool = Pool2D(pool_size=filter_size, pool_stride=stride, pool_padding=0, pool_type='avg')\n",
    "\n",
    "        # 1x1 convolution that maps in_dim channels to out_dim channels\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim,\n",
    "            num_filters=out_dim,\n",
    "            filter_size=1,\n",
    "            stride=1,\n",
    "            padding=0,\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False), # MSRA (He) normal initialization\n",
    "            bias_attr=False,                                  # no bias: batch norm adds its own shift\n",
    "            act=None)\n",
    "\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0), # scale starts at 1\n",
    "            bias_attr=fluid.initializer.Constant(0.0),  # shift starts at 0\n",
    "            act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Average-pool x, project it with the 1x1 convolution, then batch-normalize.\"\"\"\n",
    "        for layer in (self.pool, self.conv, self.norm):\n",
    "            x = layer(x)\n",
    "        return x\n",
    "\n",
    "# Queue structure: a chain of ConvUnits whose outputs are split and concatenated\n",
    "class SSRQueue(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=2, act=None):\n",
    "        \"\"\"\n",
    "        Build `queues` chained ConvUnits. Unit 0 maps in_dim -> out_dim (with the stride);\n",
    "        each later unit works on half the width of the one before it (out_dim//2, out_dim//4, ...),\n",
    "        so the concatenated output has out_dim channels in total.\n",
    "        Args:\n",
    "            in_dim  - input channels\n",
    "            out_dim - output channels\n",
    "            stride  - stride of the first unit: 1 keeps the size, 2 downsamples\n",
    "            queues  - number of chained units (channel split scale up to 2^(queues-1))\n",
    "            act     - activation for every unit\n",
    "        \"\"\"\n",
    "        super(SSRQueue, self).__init__()\n",
    "\n",
    "        self.queues = queues   # number of units\n",
    "        self.split_list = []   # split_list[i]: channels unit i passes on to unit i+1\n",
    "        self.queue_list = []   # the ConvUnits, in order\n",
    "\n",
    "        width = out_dim # output width of the unit being built\n",
    "        for idx in range(queues):\n",
    "            unit = ConvUnit(\n",
    "                in_dim=(in_dim if idx == 0 else width), # later units take exactly the channels passed on to them\n",
    "                out_dim=width,\n",
    "                filter_size=3,\n",
    "                stride=(stride if idx == 0 else 1),     # only the first unit may downsample\n",
    "                act=act)\n",
    "            self.queue_list.append(self.add_sublayer('queue_' + str(idx), unit))\n",
    "\n",
    "            if idx < queues - 1: # every unit except the last halves the width for its successor\n",
    "                width //= 2\n",
    "                self.split_list.append(width)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run x through the chain. Each non-final unit's output is split on the channel axis:\n",
    "        the leading part is collected and the trailing split_list[i] channels feed the next unit.\n",
    "        The final unit's whole output is collected, and all pieces are concatenated on axis 1.\n",
    "        \"\"\"\n",
    "        last = self.queues - 1\n",
    "        pieces = []\n",
    "        for idx, unit in enumerate(self.queue_list):\n",
    "            x = unit(x)\n",
    "            if idx < last:\n",
    "                head, x = fluid.layers.split(input=x, num_or_sections=[-1, self.split_list[idx]], dim=1)\n",
    "                pieces.append(head)\n",
    "            else:\n",
    "                pieces.append(x)\n",
    "        return fluid.layers.concat(input=pieces, axis=1)\n",
    "    \n",
    "# Basic structure: shortcut path + convolution path, added together\n",
    "class SSRBasic(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=1, is_pass=True):\n",
    "        \"\"\"\n",
    "        Build one residual structure.\n",
    "        Args:\n",
    "            in_dim  - input channels\n",
    "            out_dim - output channels\n",
    "            stride  - stride of the convolution path and of the projection pooling\n",
    "            queues  - 1 uses a single ConvUnit, >1 uses an SSRQueue of that length\n",
    "            is_pass - True: identity shortcut (the element-wise add then needs in_dim == out_dim\n",
    "                      and stride == 1); False: shortcut projected by a ProjUnit\n",
    "        \"\"\"\n",
    "        super(SSRBasic, self).__init__()\n",
    "\n",
    "        self.is_pass = is_pass\n",
    "\n",
    "        # Build the projection only when the shortcut uses it. An unused ProjUnit would still\n",
    "        # register 1x1-conv and batch-norm parameters that never receive gradients but are\n",
    "        # returned by parameters() and counted in the model size.\n",
    "        if is_pass:\n",
    "            self.proj = None\n",
    "        else:\n",
    "            self.proj = ProjUnit(in_dim=in_dim, out_dim=out_dim, filter_size=stride, stride=stride, act=None)\n",
    "\n",
    "        # Convolution path\n",
    "        if queues == 1:\n",
    "            self.conv = ConvUnit(in_dim=in_dim, out_dim=out_dim, filter_size=3, stride=stride, act='relu')\n",
    "        else:\n",
    "            self.conv = SSRQueue(in_dim=in_dim, out_dim=out_dim, stride=stride, queues=queues, act='relu')\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Add the shortcut path and the convolution path of x.\n",
    "        Returns:\n",
    "            x - the summed features\n",
    "            y - the same tensor as x (returned separately so callers can collect per-structure outputs)\n",
    "        \"\"\"\n",
    "        # Shortcut path: identity or projection\n",
    "        x_pass = x if self.is_pass else self.proj(x)\n",
    "\n",
    "        # Convolution path\n",
    "        x_conv = self.conv(x)\n",
    "\n",
    "        x = fluid.layers.elementwise_add(x=x_pass, y=x_conv, act=None)\n",
    "        y = x\n",
    "\n",
    "        return x, y\n",
    "    \n",
    "# Block: a stack of `basics` SSRBasic structures\n",
    "class SSRBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, basics=1, queues=1):\n",
    "        \"\"\"\n",
    "        Build a stack of SSRBasic structures. Only the first one changes the channel count\n",
    "        and applies the stride, so only it uses a projected shortcut; the rest use identity.\n",
    "        Args:\n",
    "            in_dim  - input channels\n",
    "            out_dim - output channels\n",
    "            stride  - stride of the first structure\n",
    "            basics  - number of stacked structures\n",
    "            queues  - queue length passed to every structure\n",
    "        \"\"\"\n",
    "        super(SSRBlock, self).__init__()\n",
    "\n",
    "        self.block_list = [] # the structures, in order\n",
    "        for idx in range(basics):\n",
    "            first = (idx == 0)\n",
    "            item = SSRBasic(\n",
    "                in_dim=(in_dim if first else out_dim),\n",
    "                out_dim=out_dim,\n",
    "                stride=(stride if first else 1),\n",
    "                queues=queues,\n",
    "                is_pass=(not first))\n",
    "            self.block_list.append(self.add_sublayer('block_' + str(idx), item))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run x through every structure in order.\n",
    "        Returns:\n",
    "            x      - output of the last structure\n",
    "            y_list - per-structure outputs\n",
    "        \"\"\"\n",
    "        outputs = []\n",
    "        for item in self.block_list:\n",
    "            x, y_item = item(x)\n",
    "            outputs += [y_item]\n",
    "        return x, outputs\n",
    "\n",
    "# Group: one SSRBlock per entry of group_arch\n",
    "class SSRGroup(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"Build one SSRBlock per (in_dim, out_dim, stride, basics, queues) tuple in group_arch.\"\"\"\n",
    "        super(SSRGroup, self).__init__()\n",
    "\n",
    "        self.group_list = [] # the stages, in order\n",
    "        for idx, arch in enumerate(group_arch):\n",
    "            stage = SSRBlock(\n",
    "                in_dim=arch[0],\n",
    "                out_dim=arch[1],\n",
    "                stride=arch[2],\n",
    "                basics=arch[3],\n",
    "                queues=arch[4])\n",
    "            self.group_list.append(self.add_sublayer('group_' + str(idx), stage))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run x through every stage in order.\n",
    "        Returns:\n",
    "            x      - output of the last stage\n",
    "            y_list - for each stage, its list of per-structure outputs\n",
    "        \"\"\"\n",
    "        outputs = []\n",
    "        for stage in self.group_list:\n",
    "            x, stage_outputs = stage(x)\n",
    "            outputs += [stage_outputs]\n",
    "        return x, outputs\n",
    "        \n",
    "# Classification network: SSR backbone -> global average pooling -> softmax classifier\n",
    "class SSRNet(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"Build the backbone, the global pooling layer and the fully connected classifier.\"\"\"\n",
    "        super(SSRNet, self).__init__()\n",
    "\n",
    "        # Feature extractor, output N*C*H*W (C is expected to equal group_dim)\n",
    "        self.backbone = SSRGroup()\n",
    "\n",
    "        # Global average pooling, output N*C*1*1\n",
    "        self.pool = Pool2D(global_pooling=True, pool_type='avg')\n",
    "\n",
    "        # Classifier, output N*class_dim; weights drawn from U(-1/sqrt(group_dim), 1/sqrt(group_dim))\n",
    "        bound = 1.0 / math.sqrt(group_dim)\n",
    "        self.fc = Linear(\n",
    "            input_dim=group_dim,\n",
    "            output_dim=class_dim,\n",
    "            param_attr=fluid.initializer.Uniform(-bound, bound),\n",
    "            bias_attr=fluid.initializer.Constant(0.0),\n",
    "            act='softmax')\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Classify the image batch x; returns N*class_dim softmax probabilities.\"\"\"\n",
    "        features, _ = self.backbone(x) # per-stage outputs are not needed for classification\n",
    "        pooled = self.pool(features)\n",
    "        flat = fluid.layers.reshape(pooled, [pooled.shape[0], -1])\n",
    "        return self.fc(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tatol param: 28207204\n",
      "infer shape: [1, 100]\n"
     ]
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.base import to_variable\n",
    "import numpy as np\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # Random input: a batch of one 3x32x32 image\n",
    "    x = np.random.randn(1, 3, 32, 32).astype(np.float32)\n",
    "    x = to_variable(x)\n",
    "\n",
    "    # Build the network and run a forward pass\n",
    "    backbone = SSRNet()\n",
    "    infer = backbone(x)\n",
    "\n",
    "    # Report the total number of parameters and the output shape\n",
    "    parameters = sum(int(np.prod(p.shape)) for p in backbone.parameters())\n",
    "\n",
    "    print('total param:', parameters)\n",
    "    print('infer shape:', infer.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd4HNW9//H3d4u0qlax3IvcwL0KY2PAGAyY6svFJHZMD+GGNEpIMCQ/AoRcHHITAqQQigMEsENIqAYcWmKCwcY27gV3LLlJsq0ubTu/P86qWmUlS1qt9H09jx7NzszOnN2RPnPmzJkZMcaglFKq63BEugBKKaXalwa/Ukp1MRr8SinVxWjwK6VUF6PBr5RSXYwGv1JKdTEa/Eop1cVo8CulVBejwa+UUl2MK1Ir7t69u8nMzIzU6pVSKiqtWbMmzxiTcTLLiFjwZ2Zmsnr16kitXimlopKI7DvZZWhTj1JKdTEa/Eop1cVo8CulVBcTsTb++vh8PrKzsykvL490UaKex+OhX79+uN3uSBdFKdXBdKjgz87OJikpiczMTEQk0sWJWsYY8vPzyc7OZtCgQZEujlKqg+lQTT3l5eWkp6dr6J8kESE9PV2PnJRS9epQwQ9o6LcS/R6VUg3pcMEfjuOlXvyBYKSLoZRSUSnqgt/rD/DV0VK+Olra6svOz89n/PjxjB8/nl69etG3b9+q116vN6xl3HDDDWzfvj3sdT799NPcdtttLS2yUko1W4c6uRuOYOjZ8P5g6z8kPj09nXXr1gFw3333kZiYyJ133llrHmMMxhgcjvr3mX/+859bvVxKKdWaoq7Gb4wN/PZswd65cycjR45k/vz5jBo1ioMHD3LzzTeTlZXFqFGjeOCBB6rmPfPMM1m3bh1+v5+UlBQWLFjAuHHjmDp1KkeOHGl0PXv27GHGjBmMHTuW888/n+zsbACWLFnC6NGjGTduHDNmzABg48aNnHbaaYwfP56xY8eye/futvsClFKdSoet8d//5ma2HCg8YXzQGMq8ARwOIc7tbNYyR/ZJ5meXjWpRebZt28bzzz9PVlYWAAsXLiQtLQ2/38+MGTOYM2cOI0eOrPWegoICpk+fzsKFC7njjjtYtGgRCxYsaHAd3/nOd7jpppuYP38+Tz75JLfddhuvvPIK999/P//617/o2bMnx48fB+APf/gDd955J1//+tepqKio2iEqpVRToq7GHylDhgypCn2AxYsXM3HiRCZOnMjWrVvZsmXLCe+Ji4vjoosuAmDSpEns3bu30XWsXLmSuXPnAnDttdfy8ccfAzBt2jSuvfZann76aYJBe1L7jDPO4MEHH+Thhx9m//79eDye1viYSqkuoMPW+BuqmReX+9idV0JCrIshGYntVp6EhISq4R07dvDoo4+yatUqUlJSuPrqq+vtMx8TE1M17HQ68fv9LVr3U089xcqVK3nrrbeYOHEiX3zxBddccw1Tp05l6dKlzJo1i0WLFnH22We3aPlKqa4l6mr8led0I9lLvbCwkKSkJJKTkzl48CDLli1rleVOmTKFl19+GYAXXnihKsh3797NlClT+PnPf05qaio5OTns3r2boUOHcuutt3LppZeyYcOGVimDUqrz67A1/oZUndyN4AVKEydOZOTIkQwfPpyBAwcybdq0Vlnu73//e2688UYeeughevbsWdVD6Pbbb2fPnj0YY7jgggsYPXo0Dz74IIsXL8btdtOnTx/uu+++VimDUqrzk3BPCoqIE1gN5BhjLq0zLRZ4HpgE5ANfN8bsbWx5WVlZpu6DWLZu3cqIESMaLUeZL8COw0WkxsfQPy0+rLJ3VeF8n0qp6CIia4wxWU3P2bDmNPXcCmxtYNo3gWPGmKHAI8AvT6ZQjYlzO3GK4HToLQmUUqolwgp+EekHXAI83cAss4HnQsOvAOdJW7bFCGjnRaWUaplwa/y/BX4M
NHSDnL7AfgBjjB8oANJPunQNEAS037pSSrVIk8EvIpcCR4wxa052ZSJys4isFpHVubm5J7EcrfErpVRLhVPjnwZcLiJ7gSXAuSLyQp15coD+ACLiArphT/LWYox50hiTZYzJysjIOKmCa4VfKaVapsngN8bcbYzpZ4zJBOYCHxpjrq4z2xvAdaHhOaF52iya9bSuUkq1XIsv4BKRB0Tk8tDLZ4B0EdkJ3AE0fEOaVuANBDlWGt5tkptjxowZJ1yM9dvf/pZbbrml0fclJtoriA8cOMCcOXPqneecc86hbvfVxsYrpVRbaVbwG2P+VdmH3xhzrzHmjdBwuTHmKmPMUGPMZGNMVN4qct68eSxZsqTWuCVLljBv3ryw3t+nTx9eeeWVtiiaUkq1mqi7ZUNbmjNnDkuXLq166MrevXs5cOAAZ511FsXFxZx33nlMnDiRMWPG8Prrr5/w/r179zJ69GgAysrKmDt3LiNGjOCKK66grKysyfUvXryYMWPGMHr0aO666y4AAoEA119/PaNHj2bMmDE88sgjADz22GOMHDmSsWPHVt3YTSmlwtFxb9nwzgI4tLHeSYMrQjc7i21m8XuNgYsWNjg5LS2NyZMn88477zB79myWLFnC1772NUQEj8fDq6++SnJyMnl5eUyZMoXLL7+8wVtH/PGPfyQ+Pp6tW7eyYcMGJk6c2GjRDhw4wF133cWaNWtITU3lggsu4LXXXqN///7k5OSwadMmgKrbMi9cuJA9e/YQGxtbNU4ppcIRlTX+GJcttmmDTp01m3tqNvMYY7jnnnsYO3YsM2fOJCcnh8OHDze4nOXLl3P11fYc+NixYxk7dmyj6/38888555xzyMjIwOVyMX/+fJYvX87gwYPZvXs33//+93n33XdJTk6uWub8+fN54YUXcLk67v5bKdXxdNzEaKRmfrywnEOF5Yzu263Vb9Y2e/Zsbr/9dtauXUtpaSmTJk0C4MUXXyQ3N5c1a9bgdrvJzMys91bMrS01NZX169ezbNkynnjiCV5++WUWLVrE0qVLWb58OW+++Sa/+MUv2Lhxo+4AlFJhicoaf1V/zjboMJqYmMiMGTO48cYba53ULSgooEePHrjdbj766CP27dvX6HLOPvtsXnrpJQA2bdrU5G2TJ0+ezL///W/y8vIIBAIsXryY6dOnk5eXRzAY5Morr+TBBx9k7dq1BINB9u/fz4wZM/jlL39JQUEBxcXFJ//hlVJdQlRWESWU/G11ocC8efO44ooravXwmT9/PpdddhljxowhKyuL4cOHN7qMW265hRtuuIERI0YwYsSIqiOHhvTu3ZuFCxcyY8YMjDFccsklzJ49m/Xr13PDDTdUPXnroYceIhAIcPXVV1NQUIAxhh/84AekpKSc/AdXSnUJYd+WubW19LbMAHlFFRwoKGNk72Rczug8aGkPeltmpTqf9r4tc8ehl+4qpVSLRWXwe7zH6C+5eqM2pZRqgQ4X/OE0PbmCFSRTqjdqa0SkmvCUUh1fhwp+j8dDfn5+k6FlxImDIHpz5voZY8jPz8fj8US6KEqpDqhD9erp168f2dnZNHWvfl/Jcdy+QvzHtuJyOtupdNHF4/HQr1+/SBdDKdUBdajgd7vdDBo0qMn5Nr76K0asf5Dd161jcBjzK6WUqtahmnrCVeGIB+BP7zd+UZRSSqkTRWXwHyyzzTsbdudEuCRKKRV9ojL4xwzuC0Df+Iae/a6UUqohURn8mb16ADAwSXv1KKVUczUZ/CLiEZFVIrJeRDaLyP31zHO9iOSKyLrQz01tU9yQmAT7K1japqtRSqnOKJxePRXAucaYYhFxA/8RkXeMMZ/Vme+vxpjvtX4R6xEKfneg6adaKaWUqq3J4Df2aqrKe/66Qz+RbWOJsQ83jw1q8CulVHOF1cYvIk4RWQccAd4zxqysZ7YrRWSDiLwiIv1btZR1VTX1aPArpVRzhRX8xpiAMWY80A+YLCKj68zyJpBpjBkLvAc8
V99yRORmEVktIqubujq3Ue44ggix2savlFLN1qxePcaY48BHwKw64/ONMRWhl08D9T51xBjzpDEmyxiTlZGR0ZLyWiJ4HXHEBtv+0YdKKdXZhNOrJ0NEUkLDccD5wLY68/Su8fJyYGtrFrI+XkccsUabepRSqrnC6dXTG3hORJzYHcXLxpi3ROQBYLUx5g3gByJyOeAHjgLXt1WBK3kdccRprx6llGq2cHr1bAAm1DP+3hrDdwN3t27RGudzxuPxalOPUko1V1ReuQs2+OO0qUcppZotioM/jlhTzqacgkgXRSmlokrUBr/XGU8iZVz1xKeRLopSSkWVKA7+OOKlgnJ/INJFUUqpqBK1we9zxJNAmT5wXSmlmilqg7/ClUyylOFEa/xKKdUcURv8Can2nvyDE30RLolSSkWXqA3+ccMyAbhsmCeyBVFKqSgTtcEv8ekAxPqOR7gkSikVXaI2+IlPBcDjL4xwQZRSKrpEb/DHpQHg8ekFXEop1RxRHPyVNX4NfqWUao7oDX5PN/w4iNOmHqWUapboDX4RiiSJeK3xK6VUs0Rv8ANFkkh8QINfKaWaI8qDP5l4bepRSqlmie7gdyQRHyyKdDGUUiqqhPPMXY+IrBKR9SKyWUTur2eeWBH5q4jsFJGVIpLZFoWtq9iRTGJAa/xKKdUc4dT4K4BzjTHjgPHALBGZUmeebwLHjDFDgUeAX7ZuMetX7EgiQWv8SinVLE0Gv7GKQy/doZ+6N0OeDTwXGn4FOE9EpNVK2YByRwKxphwC/rZelVJKdRphtfGLiFNE1gFHgPeMMSvrzNIX2A9gjPEDBUB6axa0PhWOODvgLW58RqWUUlXCCn5jTMAYMx7oB0wWkdEtWZmI3Cwiq0VkdW5ubksWUUu5I94OVGhzj1JKhatZvXqMMceBj4BZdSblAP0BRMQFdAPy63n/k8aYLGNMVkZGRstKXEPQnWgHtMavlFJhC6dXT4aIpISG44DzgW11ZnsDuC40PAf40Ji2fyiiKy7ZDlRo8CulVLhcYczTG3hORJzYHcXLxpi3ROQBYLUx5g3gGeAvIrITOArMbbMS1+CO72YHKrRLp1JKhavJ4DfGbAAm1DP+3hrD5cBVrVu0psXG2xq/r6wId3uvXCmlolRUX7nriEsCwF+mNX6llApXVAe/xNrgD5Rr8CulVLg6RfCv2LI3sgVRSqkoEtXB746Nw2uc7Mo+FOmiKKVU1Ijq4I9xOSgmjkTKaIfeo0op1SlEdfDHuhyUmDgSpJyiCr1fj1JKhSPKg99ZVeP3+oORLo5SSkWFqA5+29TjIZEy/AFt6lFKqXBEdfDXbOrxBbTGr5RS4Yju4HdXn9z1avArpVRYojr4Y5wOio2HRNGmHqWUCldUB3+s20kJcSSgTT1KKRWu6A7+0MndeCrw+rU7p1JKhSOqgz/G5aDUeHCIIVhRFuniKKVUVIjq4I91OSglFoCgPoVLKaXCEtXBH+O0NX6AgD6FSymlwhLVwS8iVTV+KkoiWxillIoS4Txzt7+IfCQiW0Rks4jcWs8854hIgYisC/3cW9+y2sLlpw0DwHg1+JVSKhzhPHPXD/zQGLNWRJKANSLynjFmS535PjbGXNr6RWzc2MF9YT2gwa+UUmFpssZvjDlojFkbGi4CtgJ927pg4XLGJtgBDX6llApLs9r4RSQT++D1lfVMnioi60XkHREZ1QplC4vTY5/CJT4NfqWUCkfYwS8iicDfgduMMXUfcrsWGGiMGQc8DrzWwDJuFpHVIrI6Nze3pWWuxelJBOCt1TtbZXlKKdXZhRX8IuLGhv6Lxph/1J1ujCk0xhSHht8G3CLSvZ75njTGZBljsjIyMk6y6FZl8MdR0SrLU0qpzi6cXj0CPANsNcb8poF5eoXmQ0Qmh5ab35oFbUhMKPgTKG+P1SmlVNQLp1fPNOAaYKOIrAuNuwcYAGCMeQKYA9wiIn6gDJhr2ukhuC63izITQ7xo8CulVDiaDH5j
zH8AaWKe3wG/a61CNYfLIRQSS7w29SilVFii+spdCF29azxa41dKqTBFffADlBJLAhUEg/owFqWUakonCX4P8ZTr4xeVUioMnSL4iUkgXiqo8GvwK6VUUzpF8HdPSyXL8SXO174N7dOZSCmlolanCP7k5BQAEre/AgFfhEujlFIdW+cI/riY6he+0sgVRCmlokCnCH7JWVP9wq/dOpVSqjGdIviZcHX1sE8fuq6UUo3pHMF/1h0sGXCfHdbgV0qpRnWO4AeCbvtAFqPBr5RSjeo0wV/kt7cd+tXS9RSWa88epZRqSKcJ/qNeJwCb9h3iuy+ujXBplFKq4+o0wY/bA0AcXnYeKY5wYZRSquPqNMFfge3LH4tXL95VSqlGdJrgLw8Fv0e8GDT5lVKqIZ0o+GMB8OCNcEmUUqpjC+eZu/1F5CMR2SIim0Xk1nrmERF5TER2isgGEZnYNsVtWGWNP06bepRSqlHhPHPXD/zQGLNWRJKANSLynjFmS415LgKGhX5OB/4Y+t1uKtv4PXi1oUcppRrRZI3fGHPQGLM2NFwEbAX61pltNvC8sT4DUkSkd6uXthFpiXFUGJdt49fkV0qpBjWrjV9EMoEJwMo6k/oC+2u8zubEnUObun/2KCqIsTV+TX6llGpQ2MEvIonA34HbjDGFLVmZiNwsIqtFZHVubm5LFtGgxFgXQVecNvUopVQTwgp+EXFjQ/9FY8w/6pklB+hf43W/0LhajDFPGmOyjDFZGRkZLSlvo3yOWDziJag1fqWUalA4vXoEeAbYaoz5TQOzvQFcG+rdMwUoMMYcbMVyhsUntqknENDgV0qphoTTq2cacA2wUUTWhcbdAwwAMMY8AbwNXAzsBEqBG1q/qE3zSSwevPiC+tB1pZRqSJPBb4z5DyBNzGOA77ZWoVrK64glTrTGr5RSjek0V+5CjRp/wPD3NdmRLo5SSnVInS74Y0O3bPjh39ZHuDRKKdUxda7gd8QQp/fqUUqpRnWq4Pc6PHhEg18ppRrTqYK/so1fKaVUwzT4lVKqi+lUwe91xOIRHw60H79SSjWkUwV/vrsPAMPlKwAyFyzlnlc3RrJISinV4XSq4N+RYJ//coZjc9W4l1Z+FaniKKVUh9Spgt+V0o8dwb6c6dgU6aIopVSH1amC/8ezTuWT4CgmO7YRgy/SxVFKqQ6pUwW/x+2EwdOJlwomyM5IF0cppTqkThX8AHsSJxIwwjSnntRVSqn6dLrg97qSWG+GMK3GCV6llFLVOl3wOx3wSXA042QXcZRHujhKKdXhdL7gF2FbcAAuCTJQjgBQ7gtEuFRKKdVxdLrgH9E7mb2mJwCZcgiAgwVa81dKqUrhPHN3kYgcEZF6O8eLyDkiUiAi60I/97Z+McP39dP6c/Ul5wLVwa+UUqpaODX+Z4FZTczzsTFmfOjngZMvVsuJCPPOHElpTHpV8H+yM4+dR4oiWSyllOowmgx+Y8xy4Gg7lKVVFcb1J9NxGICfvraJmb9ZHuESKaVUx9BabfxTRWS9iLwjIqNaaZknpTB+gDb1KKVUPVoj+NcCA40x44DHgdcamlFEbhaR1SKyOjc3txVW3bDihIH0kmMkUdqm61FKqWhz0sFvjCk0xhSHht8G3CLSvYF5nzTGZBljsjIyMk521Y0q7zsVgBmOdYAhnYI2XZ+qx6vfhi//GelSKKXqOOngF5FeIiKh4cmhZeaf7HJP1tTpszhsUpjlXMWtzn+wxnMLFB+JdLG6jtKjsH4xbF8a6ZIopepwNTWDiCwGzgG6i0g28DPADWCMeQKYA9wiIn6gDJhrjDFtVuIwicPJssBpzHEu52LnKjvy+H5I7BHZgnUV+aGb5BUejGw5lFInaDL4jTHzmpj+O+B3rVaiVvRyYDrXut6rHlF0AJgUsfJ0KVXBfyCy5VBKnaDTXblb0103zuXDwPiq18ECDaF2k7fD/i7MiWw5lFIn6NTBf9awDF7v+V2e8V9E0AgBDf72U1nj
LzsKvrLIlkUpVUunDn6AA65+/Nx/DQdIx2izQ/vJ3wmIHdbvXakOpdMHf+Vp5kMmDYr0RGO7CAYhfxf0Hmtfa/Ar1aF0/uAP/T5kUhEN/vZRsB8CFTDobPtag1+pDqXTB3+lIyYVZ7EGf7uobN8fNN3+1hO8SnUonT74Ky8pOGRScfhKoLwwwiXqAiqDv9dYiEu1RwBKqQ6jyX780a6yqWe7GWAHVjxua6DFh2HOn8GTHLGyRZwxNqS7D2vd5ebvhNhke7FcykA4tq91l9+QbW/Dp7+DvpNgyLkwZEb7rFepKNPpa/wTB6QC8O/gWPJSx8Pyh2Hdi7Dzfdi/KsKli7DPn4bfZcGej1t3uXk7IH0IiEBqJhxvp+Df+gbs+wQ++wP85QrI/bJ91qtUlOn0wb/gouH89JIRgPCNQ/MIjrwCbvrQTjyyOaJliyhvKSz/lR3+50/grduhopUeVpO/E9JDRxGpA+H4V7anT1s7ugcGngl3bANXLHz6eNuvU6ko1OmD3+10MK5/CgBfmv7cH3sn9JsESb1h5wfw6i1QdizCpYyAbW/Z5q5TLoKD62H1Itj+7skvN2+HbdPvMcK+Ts2EgLd9utIe2wNpmZCYAeO/AeuX2B2cUqqWTh/8YMO/0pLPQycae4yEPf+G9S/B589EqGQRYAyUF8CBL8AVB1c9C3NfAnc87P7XyS9/+a/ssiZcY1+nDLS/j+09+WU3xltid2Spg+zrYRfaHc6BL9p2vUpFoS4R/DE1gr/Cb5scKtKHV8/w+dNQEvE7STds10ew4WU4vAV2fWhvgbDxFfhymZ3enJuhblsKvxoG29+GXmPA7YHhl8DQmTb4T+bGqhXFsPFvMOl6W+sGW+OHtg/+yuWnhYK/32n2d3YXP4+jVD06fa8egBiXnDDuL3uSuAkoGnMdSRufh9+MgOk/gp5jYOh54HQ3vWBfGbjjwi9IeSGU5tmdTI/hEJvU9HvKjsErN9oabcoAezHUqbNg099BHDDqCht633wPHE77nmAQHA3s0/d+bC+uOrYXhp5fPX7wOfbkaO626maa5jqyFUwQMs+qHtetPyBtH/xH99jflTX+hHRIGwL7P2/b9SoVhbpEjT/0nJhaXqs4jdu9t5A99QG4ZYUN+w8fhMVft0Hr99oZAz4oyK5+XakkD349HJb9xL6uW1OuKKp9QtNfAU9Mg8cmwDMz4ckZ1Sc9v/xndc8aY6rfZwy8d68N/0AF5O8AX4kN/ZH/BfHpdjhnDez9j33PqqfsTqzyRG3Ab2+fUOng+urh3uOqh0dcBi4PrDiJO2wf3mR/9xxZPc4VAz1HwVeftny54ThWGfyZ1eP6T4b9K0/cdkp1cV0i+APB2qF83aJV+B0xvBo8C4PDBtXcl+Dbn8DM+2zN94X/hoMbbLg/Mgp+0QsenwRL5ttugl++C+XHbb/xdS/Z+VY8Dnk74fEseKgf/HIgvPYdW0tfv9gG/YyfwqWP2KD67Rj441RYPBdeuBK2vwOLZtmdjzHw0f/C2ufhzNvs7Q/i0yG2m/0Qp90EVz4DZ/0QYpLgixfsUcGGl6H4kP0NsOIx+P1k2zT0/GwbwJXL6DOh+ktJ7GGbaDYsaXnt/PBmW5ZuA2qPHzrTrretLp4rL7SfP7kvxKdVjx/5X/YI6+07T64JS6lORiL1sKysrCyzevXqdllXSYWfUT9bVu+0t75/JqP7dqs9cv1f4dX/gbgU22496yF74jB3uz0hnNTbBmX+LohJhKO7IeizJzVdHtsEM/U7tvlh499sW3rRYUjoDt/60PZv/+oz2+f840fsskygduAOv9T2vJlwDVz+OJTk2lr8fx6BPcvhB19UN+28eos9SR2fbh95iLEnr29ZYXdWR3eBwwVBv53/skchuR8Mm1n7cxcegEfH2R4xlz0a/hcc8NvP+dH/QnJv+Gad5+zu/QSevRi+9hcYeXn4yw3H+iXw9o/sTu+aV2Hw9NrTP3gAPv41zFoIU25p3XUr
FQEissYYk3Uyywjn0YuLgEuBI8aY0fVMF+BR4GKgFLjeGLP2ZArV2hJiXVw1qR9/W5N9wrTKk721jPs67P7I1tLHfA0mf6t62s4P7NFA7jY47Vv2aOGt2yFtsH3MYHJf+Npz9gImsFeRvnWb3Rlc+ZQNfYABU+zP+Pm2zznA1jftTuWdu2zoj7/aBrCI3Tkk9oCL/w/85dWhD3DBz2HgGfDuAsDA6d+GlU/Ahz+3oe/y2PdU6je5dnNMpeQ+MPFaWPMcTLoB+ow/cZ66An54cY79vqD6xmw19Z8MnhRY82fbpFRP01uLlBfY7yp9CFzwC8icduI8M35qd9jL7oH0oTDs/BPnUQogGLD/p439fQaD4C+DmIT2K1cbCOfk7rPYRys+38D0i4BhoZ/TgT+GfncoDR3XVPgC9U+Y8RNboz/zttrjh54H33wfNv/D7hASe8EXL8K0W6H/6fbeNK6Y6vknXAM7/gkDp9mgryupV/XwxGvt7x4jwVsMGaeeOL/bY39qSugOE6+xf4xbXoPzf26boj7+tS3PBQ/Ch7+A696EnNWNn7w964e2P/+iWXDdGza06zq82YaoKxY2vWJDf/pd9ihm3NwT53e64ZwFdse05TV7QroljLH/lMbAx/8Hq5+1zW2XPlK72aomhwOu+BP8eZZtphtzlZ2/5jZqTcbAqiftDiZtcO1p65fYHfvg6Xa+XR/Yk9GVlQRvCThjwRn6t9zyhj0SbOn3VbNMTe1si3NhxzIY+/XwOjaAPXeS96WtMBxYa3fuvcZA8RH7t5i7Hb5aYc+TpQy056lSB9m/7Zy10K2f7RDgirEVpvICu+70obDmWduEN+hs24nCFWs7Q/jLwRlTfQScPsSG9ZGtEJcGsYn2yLUkt/rWIf5y+71WFNr1eItsWY2xR+pBvz2PdnS37eI8eLrdFoc32e3lcEJBDiT2tMsM+mDAGfb/6PAmezRdfMSW3V9uy+UtsdPj0uz8AX/otw8mzLdNtREUVlOPiGQCbzVQ4/8T8C9jzOLQ6+3AOcaYRq/Yac+mHoAFf99Q3Ye/hkXXZzG8VzJ9UprROycabH8HPnnUHjHUtwNpTPEReOZ8e/FT91NgzJUw4Vr7D7b8YfjXQ7aHUUKGPW+R2BOudsJKAAAXHUlEQVS+/Z/GwyXgh6fPtU1e31tll3Vsr/3HSO5j5zGmdm+k3O02ALzFkL3G3mNpzFX2YrBtb9neQ6OugNO+2fRnKjoUOmfynD06OON7tiw7/gnj5lWHbUsc2mjDyh1nw/rla2D0lTBnkZ3u99rA+ONU+3rM12zI7VhmgyXzLNvsV5hjKxLj5kLfifD3b9mwPPViWxmo3AmfcqHdyW78G5Tm2zCLSbRhG5Ngv9eKQsgYbm9f4S21FZP0IfY7Ls2DpD426FIzYfUz9nxW8SEbaEG/PYLc94kNxF5j7bj8XbZc5cdtkH75rt02NYnT7qyaJICxy8HYay5qcifYq77zvgwdsVbY4Kw5PT7NdrzA2B1LSZ4N3qRedkfRe5wNYJfH7jxiEmyTbWyy/QwOlw11h9tOyzjV7jD2fmLv4dVjpG06DfrsMosO2f8Hl8f+fx3bC71G28+cmGGPGFwe+15njP1O/eV2h+Bw23U53TB6Doxv9FHmjX9zrdDU0xrB/xaw0Bjzn9DrD4C7jDGNpnp7B//REi+/eW87w3sl89PXNlWNH9uvGxuyC/j3j85hYHp0H761qgNfwMvX2n/M/B22piYO23Q0/FIbLGBrQrMWwikXNL3MnDXw1Hk2kAr2V4dGzX+w3uMBY4Nm36d2nZW1vb4TYcf79h/2rNvhzDua32z04lWwb4U9Qln1pC1H5Y6gPiV5NvCWP2zLnzbYHtkNPQ/iu9teUm/+wH6moTNtL6uig/Yf/7Sb7Hmene+HatECk2+CT39vgyvrBlsGY2wTWPdTYP9n9nqKoN8GVOZZtjZd88rnjBGQu9WGfUKGnVazKc/hst9bwAsDptrl7/+s
4e8kqbfdcWWcaq9pSc0MhdpYu4M/uMGWv1tfyF5tQ7C80H7eIaHeab3H2zLkrLZdeH2l9rsaOM1uu6O77XbM3W4/V5/xNtQTe9laur/Cjg9U2L+9lAH2p5Lfa78Td1x1k4zDYbdPwGsrD4FQjTomvnl/E1Em6oJfRG4GbgYYMGDApH372unmXTW8vi6HW5esO2F8vSd5lQ2NbUvtOQOnG8Z9w9ZmG7pOoCkb/mZroSn9YdR/28Pyza/acEjIgLztgNh/7B7DYeb99rDc4bQh76+wZarb3BWu4/th8Tw4vNEGXFJve7O+yru0GgMYG8yDz7EnzcEe1Qw7H/J32yCuGbS9x9saZOEB6H4qnH4zvPF9G8BpQ+wOa9tS2zR43r02rCp3Znk7bG2z8qgHbKhuec3WYitPVm9/19ZGj+6yO6Ah58Lk/6kOuWDAhq231H43x/fbco6fb7/L41/Z5x8Hg6Gjgj02aHe+B9Nusx0ZwNbwPSm2dpvUq/XOx6hW01GCPyqaeiq9siabO/+2/oTxPzhvGHecf0q7l0dFgL/C1uIzTrVHGp8+bgNThKrnBB/4wl71O/xSG/gj/6s6HL2ltnuqrxTKjtvaeuW0SlvesLe7rjyf4vfaHacGqTpJ7dKrJwxvAN8TkSXYk7oFTYV+JJU3cDL3sQ92aPB3Fa7Y6l5NiRlw/gMnzhPw24u/Bkyp3YMKbC176HmNr6Nut9W2OpmsVAuE051zMXAO0F1EsoGfAW4AY8wTwNvYrpw7sd05b2irwraGhoJfqVqcrvq7hyrVCTQZ/MaYRk8/G9tW9N1WK1EbmzDAHpI/fW0W/+/1TRwsqG6rDQYNDoceiiulOrcuccuGmiYNTGPz/Rcyc2RP/HVu5eANtMPDQpRSKsK6XPCDvZIXILeootb4+97YTOaCpSfc20cppTqTLhn8Dam8wKvCr+cBlFKdV5cO/imD0+od763v/j1KKdVJdOngf+Gbp7P+ZydecVrvjduUUqqT6NLB73I66Bbn5vyRPWuN1xq/Uqoz69LBX2nuaf1rvX5zwwGuXbRK2/qVUp1Sl3jmblOS42rfhvbhd7cDcPB4OZnd9cZtSqnORWv8QLe4+u8/XlDmq3e8atru3GJuXfKFNpsp1QFp8AOJsfUf+Mz+/SccL9UHdbfED5Z8wevrDrD1YBs9Z1cp1WIa/EBagr2B1swRPU6YtivX3jP+wPEyCst9fLjtcLuWLVodK9GjJaU6Km3jBzxuJzt+cRFOEf7y2T5+9sbmqmmPfrCTDdnHOV7qwyEQNLBiwbkcLixnQ3YB152RGbmCd2ClXvtg90fe/5Jnb6jn8Y1KqYjRGn+I2+nA4RDmnz6g1vjlX+ZyvNTWXivv5LBiVz5X/GFFrR2Eqq2ybf9f23N5fV1OhEujlKpJg78Ol9PBJwvO5bszhjQ4T80HuYTzIJu28tnufJ7+eHfE1t8YX437He3OLWnXdecXV/D8p3v5y6d723W9SkULDf569E2J40cXDg9r3kF3v83u3OKmZ2wDc5/8jAeXbm2VZbX2jeliXdV/Wu25awwGDZMefJ97X9/M/3tdj8iUqo8Gfyt4a8NBXl+X02jXRWMMr6zJpqDMRyBoOFJYzsPvbmPVnqPkF1c0+D6AdzYe5O5/bGxwevAkQntffgmZC5Yy5J63GXrP2y1eTk2+QJCicj+nZaYCUBZq728PR4oa/y6VUnpyt1Gvf3cas3//SZPz/ea9LwHIvrCMmSN6MiQjgSNFFVz1xKc8Nm88QQMpce56n/X7h3/tomdyLCvvmdng8m95cS0AD/33mKpxNZt4Srx+kjz1X4vQkI+2H+FIYTlPLq9ejj9oKCz3kdzMZdV1rMR2gb18XB9255awPrvgpJbXHDuOFNV6bYxB9Dm3StUS7sPWZwGPAk7gaWPMwjrTrwd+BVSexfudMebpxpYZqYetN1dhuY/EGBef7c5nY04B00/NYNZvP2719Vw9ZQD3XjqKmBpNJJty
Crjzb+vZdsiG2Z+umcTYft3ISIxl6E/eqfX+i8f04g/zJ1W9Lij14XYJ8TH179szFyxtsCx7F15yMh+Fj3fkcs0zq/j9Nyby3ZfsTuulm07njKHdT2q54fjGU5+xYld+1eufXjKCm84a3ObrVaq9tMbD1psMfhFxAl8C5wPZwOfAPGPMlhrzXA9kGWO+F+6KoyX46/P4BzvolxbHU8v3kOhxsWrP0VZd/tTB6Vw6rjc/eXVTvdOTYl0UVZzYfPL8jZP555ZD/OTikYy4910Gd0/g/tmjeG7FXs4+JYOtBwv5+ezRACfsOGq684JT+N65w1pc/sqdyuJvTWHeU58BkDUwlVduOaPFywxHIGgY/bNllNXzXOVfXzWOKyf1a9P1K9Ue2iv4pwL3GWMuDL2+G8AY81CNea6nCwV/Xc+t2Nusrp1up+ALdKynfD0wexT31jgZ+und59K7W1yzllFQ6qOgzMfZv/oIgGW3nc0Vf/iEUq8N4vfvOJshGYlt1vTy/cVf8Ob6A3x7+hASY50s23yYjTm2menUnkksu/3sNlmvUu2pNYI/nDb+vsD+Gq+zgdPrme9KETkbe3RwuzFmfz3zdErXnZHJzJE9SY13M/fJz8gv9vLxj2ew/1gpD7+7nf+a0JeB6fEcLiznrGEZgD1hW9l2D7aWfbTEx7Mr9jBrdC/e3nio3nV1T4yhT0ocG0Lt5ou/NYUPtx3mqY/3tKjsz1yXRffEWAZnJNQK/qkPfcjMET2ZM6kvZw3LqHpcZX2MMby48it++lrtI5T0xBjev2M6Zyz8EICZv1nOpIGp9EyO5ccXDm/VG+Ct2JnHm+sPADD/9AH0T4snEKQq+LcfLqLcF8DjdrbaOpWKVuHU+OcAs4wxN4VeXwOcXrN2LyLpQLExpkJE/gf4ujHm3HqWdTNwM8CAAQMm7du3r/U+SRSq8AfYcqCQCQNST5hWUOZjxc48Asaweu8xvjtjKBlJsfUuxx8IcuB4OQvf3drgDqOm+y4byaJP9vLLK8cydUh6rWkfbjvMjc82fiR263nDcDqEi0b3ok9KHI99sIM/La99PcEz12Vx3gj7nIPducWc++t/15rePTGGM4Z052tZ/flo+xHOH9mT8f1TwgrmonIfX3x1nAkDUgga+MHiL/j3l7mAPQ9y4ahegL3NRuVOB2B032SW3DyV+NA6Cst9bMop5MxhbX/uoT4f78jlUEE5V2X1b3pmpUI6TFNPnfmdwFFjTLfGltuZmno6imDQsOiTPfiDhnOH9+DPn+zhnotHkORxs37/cYb3TiLW1XSwFpT5GHf/P1tcjp9dNpIbpg2qNe5QQTlTHvqgyff2TYkj53gZkwelcUrPRL48VMyQHonEOIXnP9tHU30RNt53Qa0eTvvyS/jeS19U1fzrc+cFp/DPLYdxOoT/OXsws0b3BmzvpJR4d6NNU8YYfAFT66R85fgKf7DRHVnluZC9Cy8hEDQI4HDUXtfREi+r9uQztl8KfVKa1/SmOqf2Cn4XtvnmPGyvnc+BbxhjNteYp7cx5mBo+ArgLmPMlMaWq8HfsZV6/cS5nRSW+XlzwwEKy330TYnj/a1H6N3Nw3tbDrMnz16R2zcljumnZrA7t5gfXXgqkwbW/yzjoyVenv1kD/uPlfGfnXnktlKf+/5pcfRK9vDSt6bgdtZ/acqmnAJW7TnKwne24Q00fqvoGKeDjKRYco6XnTDtyon9cAgkelzEuZ28uPIrCsp8nDm0O3nFFcTHOFn71XEAPG4HEwekUlTuZ09eCcWhE/J9U+Lomxp3QqeA1Hg3p/RMYmVo/KSBqazZd6xq+lWT+lFQ5uPiMb3pFu/mzpfXY4AzhqTj9QfpkRzLsB5JLN1wkG+dPZhTeybxwbbDCNA3NZ5ducV4XA4GpMeTlZlGcbmfoyVevIEgThFEQBD2HS0hxung9EHplPr8rN9/nNxiL1dM6EtCjBN/0OByCCLC+v3HGZAWX2sHWVTuIzHWRV6x
l/SEmKqdWYU/wLaDRYzt161qXmMM/qDB6w9S6g3UOqrdf7SUfqlxTe54iyv8uBwODhWWM6hG86EvEKTMFwi7e/KxEi9JHhclFQG6xdd+T6nXT1G5n7gYZ73LK/X6iXU5cdaz406MdVVVDIrK7e1fmtv9uqZ2Cf7Qii4GfovtzrnIGPMLEXkAWG2MeUNEHgIuB/zAUeAWY8y2xpapwa8qbT9UhC8Q5KujpZwxJJ1duSUM75VEwBhW7MyjsNzP8VIvbqeDy8f14bkVe7np7MG8v+UwF4/p3ex2+9V7j/Lauhz2Hy3jkjG9eXPDAT7dlY8/aGvuQzISG72dtMsh+Fv5SudoEx/jpMIfrLriO9blwOmQWuMqxTgd+ILBqqO1pFgXLqfgdjooqfBTXuM9qfFuSrwBMOANBEmKdRE0hjJfgNT4mKpgNdhgLyzzUXN18TFOUuLsTii3qAJvIEj3xFhcDsEh9pYsQWMoLPPhcjpwOWw5AkHDocLyquX07uYBbE+xoDHkFdtrUxwCPZM95Bd76dktFkEIGkPO8TI8LicZSbGY0LXqxkD2sTJEoFeyB4cIecUVBI3hxxcO51tnt6ybcbsFf1vQ4FfRIBA0VFbivIFgVVOZ1x8kaAwiUO4L4gsESYhx4XBAMGhr+yLC4cJygsbQI8nD8VIvaQkxlPkCOMTuPGJdNnQ8bieF5T7cDgf7jpaQX+wlNT6GQd0T2JNXwu68YrrFuembEsfhwgoKynwUlfsoLPfTLc7N4cJyTumZRIU/gFOEEm8Arz9IzvFS/EFDssdNblEFKfFuPG4nEvoMIuB0ODhaUkFKfAxf5ZdS4vWTlhBD35Q4jpX6OHC8jP5pcQj2i4hxOThW6iXG6eBAQTn9UuMoqfAT43Tgcdtab0mFn8/3HmVsvxQSYl24nYIxsPlAQejEu6n63IGg4UhROQVlPk7pmUSsy0FxRYCdR4oY2iOR+BgXsaF1GmMvNHQ7hVJvgBingx1HihnVJ5lDBeX0SPbgC9idTG5xBUXlPob1SATsTRb9gSAOsWEdMJAQ46w64ln71TH6p8UTNJCRGItDwOkQHA5hb14JG7Jt2Uf2TsbpgAq/XZYI7Dxit0/3RHvEUlnv35BTQL/UOHokxRIIQl5xBU6HMHt8H2aP79uiv0kNfqWU6mJaI/j1Xj1KKdXFaPArpVQXo8GvlFJdjAa/Ukp1MRr8SinVxWjwK6VUF6PBr5RSXYwGv1JKdTERu4BLRHKBlt6eszuQ14rFiTT9PB2bfp6Orat9noHGmIyTWUHEgv9kiMjqk71yrSPRz9Ox6efp2PTzNJ829SilVBejwa+UUl1MtAb/k5EuQCvTz9Ox6efp2PTzNFNUtvErpZRquWit8SullGqhqAt+EZklIttFZKeILIh0eSqJSH8R+UhEtojIZhG5NTQ+TUTeE5Edod+pofEiIo+FPscGEZlYY1nXhebfISLX1Rg/SUQ2ht7zmDT2TLrW+1xOEflCRN4KvR4kIitDZfiriMSExseGXu8MTc+ssYy7Q+O3i8iFNca367YUkRQReUVEtonIVhGZGs3bR0RuD/2tbRKRxSLiiabtIyKLROSIiGyqMa7Nt0dD62ijz/Or0N/bBhF5VURSakxr1vfekm3bIGNM1PxgH/24CxgMxADrgZGRLleobL2BiaHhJOxzikcCDwMLQuMXAL8MDV8MvIN9WM8UYGVofBqwO/Q7NTScGpq2KjSvhN57UTt8rjuAl4C3Qq9fBuaGhp/APmYT4DvAE6HhucBfQ8MjQ9spFhgU2n7OSGxL4DngptBwDJASrdsH6AvsAeJqbJfro2n7AGcDE4FNNca1+fZoaB1t9HkuAFyh4V/W+DzN/t6bu20bLWtb/qO1wR/7VGBZjdd3A3dHulwNlPV14HxgO9A7NK43sD00/CdgXo35t4emzwP+VGP8n0LjegPbaoyvNV8bfYZ+wAfAucBboX+gvBp/yFXbA1gGTA0Nu0LzSd1tVDlf
e29LoBs2KKXO+KjcPtjg348NPFdo+1wYbdsHyKR2ULb59mhoHW3xeepMuwJ4sb7vs6nvvSX/e42VM9qaeir/2Ctlh8Z1KKFDrQnASqCnMeZgaNIhoGdouKHP0tj47HrGt6XfAj8GgqHX6cBxY4y/njJUlTs0vSA0f3M/Z1sZBOQCfxbbdPW0iCQQpdvHGJMD/B/wFXAQ+32vIXq3T6X22B4NraOt3Yg98oDmf56W/O81KNqCv8MTkUTg78BtxpjCmtOM3SVHRTcqEbkUOGKMWRPpsrQSF/Yw/I/GmAlACfYwv0qUbZ9UYDZ2h9YHSABmRbRQraw9tkd7bXMR+QngB15s63WFI9qCPwfoX+N1v9C4DkFE3NjQf9EY84/Q6MMi0js0vTdwJDS+oc/S2Ph+9YxvK9OAy0VkL7AE29zzKJAiIq56ylBV7tD0bkA+zf+cbSUbyDbGrAy9fgW7I4jW7TMT2GOMyTXG+IB/YLdZtG6fSu2xPRpaR5sQkeuBS4H5oR0NNP/z5NP8bduwtmqDbIsfbK1tN7aWU3niY1SkyxUqmwDPA7+tM/5X1D6R9HBo+BJqn6xaFRqfhm2LTg397AHSQtPqnqy6uJ0+2zlUn9z9G7VPMH0nNPxdap9gejk0PIraJ7F2Y09gtfu2BD4GTg0N3xfaNlG5fYDTgc1AfGh9zwHfj7btw4lt/G2+PRpaRxt9nlnAFiCjznzN/t6bu20bLWdb/qO10R/8xdgeM7uAn0S6PDXKdSb2kHEDsC70czG2re0DYAfwfo0/SgF+H/ocG4GsGsu6EdgZ+rmhxvgsYFPoPb+jiRM4rfjZzqE6+AeH/qF2hv4QY0PjPaHXO0PTB9d4/09CZd5OjZ4u7b0tgfHA6tA2ei0UFFG7fYD7gW2hdf4lFCJRs32AxdjzEz7sEdk322N7NLSONvo8O7Ht75WZ8ERLv/eWbNuGfvTKXaWU6mKirY1fKaXUSdLgV0qpLkaDXymluhgNfqWU6mI0+JVSqovR4FdKqS5Gg18ppboYDX6llOpi/j+zkGglSm7eJQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "complete - train time: 24187s, best epoch: 152, best loss: 0.970739, best accuracy: 79.13%\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.utils.plot import Ploter\n",
    "import numpy as np\n",
    "import time\n",
    "import math\n",
    "import os\n",
    "\n",
    "epoch_num = 300   # 训练周期，取值一般为[1,300]\n",
    "train_batch = 128 # 训练批次，取值一般为[1,256]\n",
    "valid_batch = 128 # 验证批次，取值一般为[1,256]\n",
    "displays = 100    # 显示迭代\n",
    "\n",
    "start_lr = 0.00001                         # 开始学习率，取值一般为[1e-8,5e-1]\n",
    "based_lr = 0.1                             # 基础学习率，取值一般为[1e-8,5e-1]\n",
    "epoch_iters = math.ceil(50000/train_batch) # 每轮迭代数\n",
    "warmup_iter = 10 * epoch_iters             # 预热迭代数，取值一般为[1,10]\n",
    "\n",
    "momentum = 0.9     # 优化器动量\n",
    "l2_decay = 0.00005 # 正则化系数，取值一般为[1e-5,5e-4]\n",
    "epsilon = 0.05     # 标签平滑率，取值一般为[1e-2,1e-1]\n",
    "\n",
    "checkpoint = False                   # 断点标识\n",
    "model_path = './work/out/ssrnet'     # 模型路径\n",
    "result_txt = './work/out/result.txt' # 结果文件\n",
    "class_num  = 100                     # 类别数量\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # 准备数据\n",
    "    train_reader = paddle.batch(\n",
    "        reader=paddle.reader.shuffle(reader=paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "        batch_size=train_batch)\n",
    "    \n",
    "    valid_reader = paddle.batch(\n",
    "        reader=paddle.dataset.cifar.test100(),\n",
    "        batch_size=valid_batch)\n",
    "    \n",
    "    # 声明模型\n",
    "    model = SSRNet()\n",
    "    \n",
    "    # 优化算法\n",
    "    consine_lr = fluid.layers.cosine_decay(based_lr, epoch_iters, epoch_num) # 余弦衰减策略\n",
    "    decayed_lr = fluid.layers.linear_lr_warmup(consine_lr, warmup_iter, start_lr, based_lr) # 线性预热策略\n",
    "    \n",
    "    optimizer = fluid.optimizer.Momentum(\n",
    "        learning_rate=decayed_lr,                           # 衰减学习策略\n",
    "        momentum=momentum,                                  # 优化动量系数\n",
    "        regularization=fluid.regularizer.L2Decay(l2_decay), # 正则衰减系数\n",
    "        parameter_list=model.parameters())\n",
    "    \n",
    "    # 加载断点\n",
    "    if checkpoint: # 是否加载断点文件\n",
    "        model_dict, optimizer_dict = fluid.load_dygraph(model_path) # 加载断点参数\n",
    "        model.set_dict(model_dict)                                  # 设置权重参数\n",
    "        optimizer.set_dict(optimizer_dict)                          # 设置优化参数\n",
    "    else:          # 否则删除结果文件\n",
    "        if os.path.exists(result_txt): # 如果存在结果文件\n",
    "            os.remove(result_txt)      # 那么删除结果文件\n",
    "    \n",
    "    # 初始训练\n",
    "    avg_train_loss = 0 # 平均训练损失\n",
    "    avg_valid_loss = 0 # 平均验证损失\n",
    "    avg_valid_accu = 0 # 平均验证精度\n",
    "    \n",
    "    iterator = 1                                # 迭代次数\n",
    "    train_prompt = \"Train loss\"                 # 训练标签\n",
    "    valid_prompt = \"Valid loss\"                 # 验证标签\n",
    "    ploter = Ploter(train_prompt, valid_prompt) # 训练图像\n",
    "    \n",
    "    best_epoch = 0           # 最好周期\n",
    "    best_accu = 0            # 最好精度\n",
    "    best_loss = 100.0        # 最好损失\n",
    "    train_time = time.time() # 训练时间\n",
    "    \n",
    "    # 开始训练\n",
    "    for epoch_id in range(epoch_num):\n",
    "        # 训练模型\n",
    "        model.train() # 设置训练\n",
    "        for batch_id, train_data in enumerate(train_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = train_augment(image_data)                                                        # 使用数据增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "\n",
    "            label_data = np.array([x[1] for x in train_data]).astype(np.int64)                        # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                             # 转换数据类型\n",
    "            label = fluid.layers.label_smooth(label=fluid.one_hot(label, class_num), epsilon=epsilon) # 使用标签平滑\n",
    "            label.stop_gradient = True                                                                # 停止梯度传播\n",
    "\n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label, soft_label=True)\n",
    "            train_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            # 反向传播\n",
    "            train_loss.backward()\n",
    "            optimizer.minimize(train_loss)\n",
    "            model.clear_gradients()\n",
    "            \n",
    "            # 显示结果\n",
    "            if iterator % displays == 0:\n",
    "                # 显示图像\n",
    "                avg_train_loss = train_loss.numpy()[0]                # 设置训练损失\n",
    "                ploter.append(train_prompt, iterator, avg_train_loss) # 添加训练图像\n",
    "                ploter.plot()                                         # 显示训练图像\n",
    "                \n",
    "                # 打印结果\n",
    "                print(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\".format(\n",
    "                    iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "                \n",
    "                # 写入文件\n",
    "                with open(result_txt, 'a') as file:\n",
    "                    file.write(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\\n\".format(\n",
    "                        iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "            \n",
    "            # 增加迭代\n",
    "            iterator += 1\n",
    "            \n",
    "        # 验证模型\n",
    "        valid_loss_list = [] # 验证损失列表\n",
    "        valid_accu_list = [] # 验证精度列表\n",
    "        \n",
    "        model.eval()   # 设置验证\n",
    "        for batch_id, valid_data in enumerate(valid_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = valid_augment(image_data)                                                        # 使用图像增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "            \n",
    "            label_data = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64) # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                       # 转换数据类型\n",
    "            label.stop_gradient = True                                                          # 停止梯度传播\n",
    "            \n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算精度\n",
    "            valid_accu = fluid.layers.accuracy(infer,label)\n",
    "            \n",
    "            valid_accu_list.append(valid_accu.numpy())\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label)\n",
    "            valid_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            valid_loss_list.append(valid_loss.numpy())\n",
    "        \n",
    "        # 设置结果\n",
    "        avg_valid_accu = np.mean(valid_accu_list)             # 设置验证精度\n",
    "        \n",
    "        avg_valid_loss = np.mean(valid_loss_list)             # 设置验证损失\n",
    "        ploter.append(valid_prompt, iterator, avg_valid_loss) # 添加训练图像\n",
    "        \n",
    "        # 保存模型\n",
    "        fluid.save_dygraph(model.state_dict(), model_path)     # 保存权重参数\n",
    "        fluid.save_dygraph(optimizer.state_dict(), model_path) # 保存优化参数\n",
    "        \n",
    "        if avg_valid_loss < best_loss:\n",
    "            fluid.save_dygraph(model.state_dict(), model_path + '-best') # 保存权重\n",
    "            \n",
    "            best_epoch = epoch_id + 1                                    # 更新迭代\n",
    "            best_accu = avg_valid_accu                                   # 更新精度\n",
    "            best_loss = avg_valid_loss                                   # 更新损失\n",
    "    \n",
    "    # 显示结果\n",
    "    train_time = time.time() - train_time # 设置训练时间\n",
    "    print('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}'.format(\n",
    "        train_time, best_epoch, best_loss, best_accu))\n",
    "    \n",
    "    # 写入文件\n",
    "    with open(result_txt, 'a') as file:\n",
    "        file.write('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}\\n'.format(\n",
    "            train_time, best_epoch, best_loss, best_accu))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "infer time: 0.010264s, infer value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGylJREFUeJztnWlsXOd1ht9zZ+Um7pQoURIteZUVW04c17HjVFmcOGkAJ0VhJGgDA3WWAgnaoPljuECbAv2RAk2CoghSJKhrB0jjpHFSu47T2HGdOnYTWbIta7MsibQW7hSX4XAWznK//phhyuF7eDniSBQpnwcQxDm8c+937/DMved857yfOOdgGIaOd7kHYBhrGXMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAIwBzGMAGpyEBG5R0TeFJFTIvLgxRqUYawVZKUz6SISAnACwN0ABgDsB/Bp59yxpd7T0dHhent7V3Q8Yzn4c8zPzZEtlU6TrbFpg7rHcDhc+7BWgK/YisWCuu3cXJZsoTB/7+dylduNjYwjMZ2U5cZSyxW4DcAp51w/AIjIYwDuBbCkg/T29uLAgQM1HNJYkiI7w8jZPrLte/lVst31oXvUXba1d9Q+rmUoKrZ0ka3J2Un1/f19b5Cttb2BbGfPnqx4/eefe6iq8dXyiLUFwLkFrwfKtgpE5PMickBEDoyPj9dwOMNYfS55kO6c+45z7lbn3K2dnZ2X+nCGcVGp5RFrEMDWBa97yrYLwqqJLxxfeR6X/BTZkmP9ZHv+yZ/wdkl+jgeAP/nsZ9mofF6+r3yGylevAz/y55X3Dg2fJdvk9IA6xuFzR8nWf/I82RIzlddnLptS97eYWu4g+wFcIyJXiUgUwKcAPFnD/gxjzbHiO4hzriAiXwLwCwAhAA8759idDWMdU1Mezzn3NICnL9JYDGPNYTPphhHA5ZkJWgaRZedv3jZoKQxPlNmDYpLfm+G0eoOfI9vE8Ih67NGRUbKFhL9Tm1uayRaJRsjmK0G6czwtGOa3Il/MqGNs39hOttFxDtKH+4Yq95fPq/tbjN1BDCMAcxDDCMAcxDACMAcxjADWZJC+WmhVo87nor/CFAd9mcQsvzfKRXIbtmzWD64Eu6IErJ7Ps+Yzw+fIdvrIb8n21hvHeX9eVNkfz1wDwK+efpxsrZu3ku2OO+/iN4e5QnhiOkG2uVlOEGSzY2RzBU5CAMDYJFcLTE3z5+X8xde7ukSQ3UEMIwBzEMMIwBzEMAIwBzGMAN7WQTp8npE+f4oD27FXXiRbepIDzpEcf99ce9de9dDX3Hwr2bwIfxyHjx4m22vPP0+2pBK4z4zxTHgkHCNbdmKIbADw/M/OkO2G3/8I2d7zvg/yPud4xn5qjPfXv59L+UaHuBOyffs2dYxpn8vW82m+jlGvq+K1VPmnb3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIoKYsloicBpBESd6o4Jzj1MwaxmW5rGTiTc6gYHqGTG0hRcjM48xN/wvPqscOOy51iG/mTM33fvyfZDt64CDZdrRymUubx2NsUDJlxZDSgAGg/wRnt1488WOydffcSLa7bruBbOPH/5dsrz/zU7LNTbMARWpwlzrG+l3vYlsd63k1XdVa8Toaq04+4WKked/vnOPiF8O4ArBHLMMIoFYHcQCeEZFXROTz2gamrGisZ2p1kPc6594J4KMAvigi71u8gSkrGuuZWmV/Bsv/j4nIT1EStH7hgnZyGfUZvCj3RjR2cf/G+MBbZMuOs9JfQ5T7OWay+gke/61SvtK6nWzPPPMSb5fk3ogmr5ttrXGypeY4cD9+VhdtGEmxZMTABAfQ33/kX3m7g11kS59j4fKGIpeKxOq4HGYuxar0ALC9kQNyb+PVZMtK5Wcd0pQhFFZ8BxGRBhFpmv8ZwIcBHFnp/gxjLVLLHWQjgJ+WJXrCAP7NOfdfF2VU
hrFGqEV6tB/AzRdxLIax5rA0r2EEcPn7QTTpwGoD96VWTqjy/U5ZYmzTO/immJ+dJlvf2TfJlp7kNHYuVqce+8QJXhkp1cjqgeE8n+TMBK+2lFBWVYpv58B9ZoqD7ENn9CB9PMdJjKZmVlE8e+p1su2b5CUVrungwDga4fObnmNbU5d+HYeHuA9mQ30bH6dtkQKjVLfsht1BDCMAcxDDCMAcxDACMAcxjAAue5CuxUpKJfgS772A9Q2VJRVEWR8vEuPZ5y233cn7UyZih1/lWe8eRYkQACbOs2DEoX2vka0uzIF7RxMHz3vv4jH+3s1cIv5P3/oW2ZIZLtMH9GuhKRymlVnu2FZelsB3HLiPjnErQbh1I9mkQS9Tev0otyckXmHhje4dOypep2b4uBp2BzGMAMxBDCMAcxDDCMAcxDACWPUgffGi85qH+krwnc1x/3hUmQkH9HX0PG16XQncC8r0fN8kdxRPKQHs3LW7yXbju+5Qx5g/y7PhP/rZL3m7DJeDf/KevWT7w49/mGwnT/HSAGMpTg7kXEgdY8TxttEwb9sU52vR0MJBdSLP59KwkWf7XR0vnTAwri9/UMxwEiOnaAg8/2RloXlymqsjNOwOYhgBmIMYRgDmIIYRgDmIYQSwbJAuIg8D+DiAMefc7rKtDcAPAfQCOA3gPucc11EvwncOc/nKWdu40hc+k+b1/17av49sGxob1ePccuNNZGuqqydbscj92YPjLJb2qxc5eH7rLK/rN6fMSMc296pjLCR5VnnsDC8PMJvka7Gzl2fnw+CAejrBwWrO5yC7UNRWawT8NAfGnuMSglCcP8OJSf5zGB3jZEedsq5jQzMnZBpbeDsAaFKSBnVhTrRs7WipeN13Tl/yYTHV3EEeAXDPItuDAJ5zzl0D4Lnya8O44ljWQZxzLwBYnJO8F8Cj5Z8fBfCJizwuw1gTrDQG2eicGy7/PIKSgIPKQuG48yYcZ6wzag7SnXMOSze/VgjHdZhwnLHOWOlM+qiIdDvnhkWkGwCv/K4gAsiioGpmloPQ/QdfJdvZ4UGyxaIsMAYAnW0sJnZd706yJWYmyHbwIAu6DZ8+RraRsxxwjk3xuRw8zIrmAHBbz/Vk27GJv0Cm2ri/urmDZ5/PDXFf+fAwB6KpJAfPLY16v3dqloP0mSmuANjR1UO2xjj/aaXrFGX5AidKiikeY9HTy9NzrVxWjzAnLJqbK88xHKru3rDSO8iTAO4v/3w/gCdWuB/DWNMs6yAi8gMAvwFwnYgMiMgDAL4G4G4ROQngQ+XXhnHFsewjlnPu00v8itf+NYwrDJtJN4wAVrXc3flAca4ygHpp38u03StHD5Ft5/UcCA6dS6jH+Y+nniPbxz+WJ1vfaRZv6zvHSu5eiMu5J5VZ4cGB02SLF9+tjvEdvb1k+7M//QzZtNnwnS0s3jY0xEmMk4c5uZCc4FR7c7sS6AIoFpQydmXSfUtrE9mcshyd+PzmkMcJ0FBIaUPI8+cHAGlF1C8U5pn9ol+ZDHDQqwcWY3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIYFWzWEW/iORsZebpv1/gXov2zVwqMpfl/okz/bpsvyiZkZcPserhESVbJsolCWmXKcw9C3s/uIdsXa1cKgIAhTRneXZfdx3ZPGW5goFfcJau7jxnc+5u4nUCN13LvTIHxofJBgDH67j3o7eHy1w6lbKSbJbLVLS+E9/n7JS2fmAsrJfD5JSelajS++NF9LKk5bA7iGEEYA5iGAGYgxhGAOYghhHAqgbp4gkiDZXBUnMbCy8MDrKk/aHXeQn2M6e4/wIAuns4oGvfxCUbvs+9CFOTvM+IEvT37lAC4M1ccpGZ00skclkO0ouK6EPmNJeQpE9zUJ1IcDBfp5SkvHsbl+x0x3jcALBhgvtJwq0snuBH+Dq6IgfaogTkxTwnX0SLpxWxidI+ufejMMf7jHqL329rFBpGzZiDGEYA5iCGEYA5iGEEsFJlxa8C
+ByA+eaCh5xzTy+3r1Q6i32vVfZgFBXp/VCIh/VWP/dpDA7qQXpjK4sfFIutZEsmeW09LUi/Sglsuzo5SB8YOEG21rAusx+5kRMJ4QRL+Z87eJRsR2d4GYGfHePtEj4Hqy1xnmX+8HW3qmO8I8oKjudGT5Mt1MwBeaGeezrySvDsfE5MOJ8/fy3wBoBiUZmJd8qM/eKlMqpc33KlyooA8E3n3J7yv2WdwzDWIytVVjSMtwW1xCBfEpFDIvKwiPDzS5mFyoqJKlf1MYy1wkod5NsAdgLYA2AYwNeX2nChsmJzS8tSmxnGmmRFM+nOudH5n0XkuwCequZ9c7kM3jp9uHIAilR9VzuXu4vSZB+v02dXP/SBj5Dt+l07yFacYwXHrjZFOr97G9k623j2ecdWLlff1rlZHaMm7JcY4uUPJmZYtLIfHJg23cRl7IUMVw9MT7LQxRNnWNwBAG7s4tL2q7Rp7hFOLmSaeYbbFbhFoFDgIN3Pc9BfXGLmO53lpEq8QVlbsW7xuC/hTHpZbnSeTwLgOhDDuAKoJs37AwB7AXSIyACAvwGwV0T2oOSGpwF84RKO0TAuGytVVvyXSzAWw1hz2Ey6YQSwquXu0aiPzb2VAV1rB8/s5vMcuH3kD1ihcGKCg0MACMc5SMvleJ+33HIj2bIpDiSHlKUO9tzA793Zu51s0+d12f7hES4lnzw3QDbvat7nXe/fS7asx4HtzCxfnwJfGhx98zAbAZx98xTZukIc3G7wOIHifN7OE95OlJYDpwyysERMnVMUF8NFRZmxUHktnDLbrmF3EMMIwBzEMAIwBzGMAMxBDCOAVQ3Sk6kEXtj/8wpbQQnItvVyufqeO3aR7UyfLhznCQe7k7O8HqFf5Jn4ZIKDxokZDrRffp1npI/38ez64KAepMeV8u3rY7wMgdfAM/EjSln8S/t/TbaCEodGYlxmn5jVVx/ORfj6JOKcDAiHeLs0+PyKSv94aHEZOoCwYssraxkCgCf8HR8K83iyc5XJF19JIqj7r2orw3ibYg5iGAGYgxhGAOYghhHAqgbpsXgYO6+uDETzSrlz1yZtVphLwZMpvdExHOaS7HyR19tLJDmAzitTtm09nDSIxDhID8W5V3z79fp3kF9ke1OYg/xfv8jrKB49yWJyTU3cayOeorqe40qBiWn9OvqO3+8UtfqkokCfyXG/vwjPcEejvJ6gZsso6v4AEI7y34rn8bUtUILAgnTDqBlzEMMIwBzEMAIwBzGMAKrpKNwK4HsANqIU2XzHOfePItIG4IcAelHqKrzPOcfR2gIa6uK4dU9l3/asUpJ97NjrZJuc5l1fv2u3epymxg3amZBlbJwDtXyOt0tO8zJfMymefW5v26TYdMGX2Sx/N8VDHGiH6zlwL+b5mkWFVfLrG1mJ3VMSAdPj59QxtnT3kq01yn8yiUkWzPOFky+xGAffnhK4Fwpcwq61QABAg7LcWlEpIWhorFS69zxddJDGV8U2BQBfcc7tAnA7gC+KyC4ADwJ4zjl3DYDnyq8N44qiGuG4Yefcq+WfkwDeALAFwL0AHi1v9iiAT1yqQRrG5eKCYhAR6QVwC4B9ADY65+ZXchlB6RFMe8/vhOOmJ3mewDDWMlU7iIg0AngcwJedcxUzbM45hyVmXhYKx7W08TOxYaxlqnIQEYmg5Bzfd879pGwendfHKv/PCmeGsc6pJoslKMn8vOGc+8aCXz0J4H4AXyv//8Ry+yr6BSRmKwUQPHBZyEyCsxDHj3PW6FT//6jH6dnGyow37dlJtm3KdnUeZ8CcIgJQVPpYohHutRCuhAAA1Gf4httdz2O8ZQ9naTqaudzjpRdeIltiirWQtf6b8UH9u801cH9K8VoeI5TrowlnxMJ8MTIpLknxi9z7EY3r3+UhRXEzl1GUKRZXGlVXaVJVLdadAD4D4LCIHCzbHkLJMX4kIg8AOAPgvuoOaRjrh2qE416ENolQ4oMXdziGsbawmXTDCMAcxDAC
WNV+EE+A+milTzqfg6w7b38X2XbuvIFs/WdOq8cZG2fRhukJRSY/wgmC0QwnA1paOHBvauKSDRdRylRmuG8EANoaeN3Dzi7uO0lu5cB//29+Q7aJaVZ/9JVrqyHcKgMAaGvjX7Rt4XKYlPI1G1HEFKLachXC0XImw6U0ztOj6oKizKiddnrRPqu9NnYHMYwAzEEMIwBzEMMIwBzEMAJY1SAd4uCFKoMqL6LI6SsL03ds2kK2G3br6/9lsxzk+Yqq3/D5YbKNJTjYHZsZJdumbg6om5s5qPWX6DuYzfN300T2ZbINTrKwxJFjPGs+l+Vxx+NLRN+LaGjWA+CtbUrvR/Is2bwWPk5LhKsUfHBPhyqw4Pizmk3q1zHkKYG/sgAkTfYvNbO3CLuDGEYA5iCGEYA5iGEEYA5iGAGsapCezc3hxFDlunfNLTwjHctxYLohzs1WrcpsNgDEldJoDywY0NXK5dyRMM9czyR5dj3kOMqbmeby8tFxXnYBABKjrBR5qoPFKnqabyHbH9/3PrId3s/v1dZlbGllEYk5pUwfANw0VwEcOXaIbL2dLBjR3sAl+QVFCXNCKW3fEOHZeqeIOwDAbIIFNeL1/LdSv6FyjJ6nVzgsxu4ghhGAOYhhBGAOYhgBmIMYRgC1KCt+FcDnAMxHsA85554O2lfRL2J6tjIAzxZY1j6mLC2Qb2omW3J2KXU8LmWur+PArbG+m2zxKAecnc1c7p5X1A215RQGTg2pIwwrSxMcGmWFw3PKZPi1US79b1Ouz+YurjTwlPLwbL0eAE9EuFd9CzgxUhfmY9c1KIqQaT6ZfJFVFHNZXqIhn9PXKEwrypyxGB+7tbVS9TIUrk5jpJos1ryy4qsi0gTgFRF5tvy7bzrn/qGqIxnGOqSanvRhAMPln5MiMq+saBhXPLUoKwLAl0TkkIg8LCKqSvNCZcVUgm+nhrGWqUVZ8dsAdgLYg9Id5uva+xYqKzYoVbqGsZapaiZdU1Z0zo0u+P13ATy13H6ikTh6Nl5dYSsoUvWeUq6cyfCs8Ni0rvWrzXxv3c5LE6QVOf5skvfZ2KjMFLcrs/ARFnnbsV1f/6++kQPW/j4u3Y6FlSUMuvmatWzkRMLsLM8yh4ocAO+88WqyAYB/nMvO8wUedzymLEHg8RjbG3m7cITPeeo8Vx+Iz/oBAJDO8FNJOMbbeqHKP3VtvUSNZe8gSykrzsuOlvkkgCNVHdEw1hG1KCt+WkT2oJT6PQ3gC5dkhIZxGalFWTFwzsMwrgRsJt0wAljVcnfnisgVKoPgWIxLrRvquNy5WOCZ1HSClcEBoKGeA79ingPyyTSvexhX1uDTFNp9jwPYdI5n9rs2aeslAvX1HLBu2qSUiBf5OHM+zx63t3EPeCbB28UjnHAI1fN2ABAf54C8boTPx/M58C+Ckx1eiD/rugb+rNMpTshE4rrQW9FxQsYXDtwzhcoqB1/pe9ewO4hhBGAOYhgBmIMYRgDmIIYRwKoG6UW/iFS6cma54LNoWXKWhdpCwkGtCAe1ANDcxPZ0mvcZUZYEkzAH+KksB9/JIS5t12auoZwfADifM+chRR3e95VgV8m6F9PcIhAOcWCbSnNAnczpffPSzLP40sABfeo8B9V5JQgugI89l+HrmHccZA8MD6pjHBnjSoXOzZwMcOnKJE9RKfvXsDuIYQRgDmIYAZiDGEYA5iCGEYA5iGEEsLqlJr6HfKayVCE1y83z2kLyuRxnaaJKuQcATL3FJSgzKc6C7H7HtWRLjHBGxxO+TOoad0pm6q0+PfsSi3JWrqWNsy/Nrfwd1tzCZTPIcbYrrpSzJGZZJCOd5iwUALiMIvAQ4cxfHlx+4ucVgYYQfy75MGex0nnOTPWfZUELAEgm+G+gpYf7QQpe5Tk66NnFxdgdxDACMAcxjADMQQwjgGpabuMi8rKIvC4iR0Xkb8v2q0Rkn4icEpEfiojyYGwY65tqgvQ5AB9wzs2WxRteFJGfA/hLlITj
HhORfwbwAEpKJ0uSz/kYGqgsx/CVwDYa4RKHwWEOnnM5XRAhrCxh0NLKgeTgsFLS4vF4PPD+6pW+Ck2VMRzTpY6OnzpOts1ZHmP4PJdnRCKcIGisZzXBhgZWPMxkOEgPRZfqteAAujHew9t5SsNMhktSpgp8vaWLy3MmZ/mzTs7qY8w6/o7vfScrT+6+ZXvF64OHn1H3t5hl7yCuxHwxUqT8zwH4AIAfl+2PAvhEVUc0jHVEVTGIiITKgg1jAJ4F0Adg2jk3nwccwBJqiwuF49KzejrRMNYqVTmIc67onNsDoAfAbQCur/YAC4Xj6hstTDHWFxeUxXLOTQN4HsB7ALSI/G4GrQeAPiNmGOuYapY/6ASQd85Ni0gdgLsB/D1KjvJHAB4DcD+AJ5bb19xcHn19w5X7V5YqaGpk28wU+3IyqT+y7drNsv+921kJcWDoNB+7iSWGXZ5nXesbOKCOKYF77zZdwa+tjWeas1meaZ5W1glMTClqlG3Kun557m3xPD5uInVeHWOuyLPz0wkWSdiQ4hn7mBI8Zz3eXyzK2yWSSh9LSv8ub97CTyXxTkW0o7EyOeGUXhmNarJY3QAeFZEQSnecHznnnhKRYwAeE5G/A/AaSuqLhnFFUY1w3CGUFN0X2/tRikcM44rFZtINIwBzEMMIQJyrruz3ohxMZBzAGQAdAPTIcP1h57I2We5ctjvnOpfbyao6yO8OKnLAOXfrqh/4EmDnsja5WOdij1iGEYA5iGEEcLkc5DuX6biXAjuXtclFOZfLEoMYxnrBHrEMIwBzEMMIYNUdRETuEZE3y626D6728WtBRB4WkTERObLA1iYiz4rIyfL/XO24BhGRrSLyvIgcK7dS/0XZvu7O51K2ha+qg5QLHr8F4KMAdqG0Uu6u1RxDjTwC4J5FtgcBPOecuwbAc+XX64ECgK8453YBuB3AF8ufxXo8n/m28JsB7AFwj4jcjlLV+Tedc1cDmEKpLfyCWO07yG0ATjnn+p1zOZRK5e9d5TGsGOfcCwAWN8Lfi1LLMbCOWo+dc8POuVfLPycBvIFSV+i6O59L2Ra+2g6yBcBCibwlW3XXERudc/NNLiMANl7OwawEEelFqWJ7H9bp+dTSFh6EBekXEVfKma+rvLmINAJ4HMCXnauUMVlP51NLW3gQq+0ggwC2Lnh9JbTqjopINwCU/2ex4TVKWcbpcQDfd879pGxet+cDXPy28NV2kP0ArilnF6IAPgXgyVUew8XmSZRajoEqW4/XAiIiKHWBvuGc+8aCX6278xGRThFpKf883xb+Bv6/LRxY6bk451b1H4CPATiB0jPiX6328Wsc+w8ADAPIo/RM+wCAdpSyPScB/BJA2+UeZ5Xn8l6UHp8OAThY/vex9Xg+AG5Cqe37EIAjAP66bN8B4GUApwD8O4DYhe7bSk0MIwAL0g0jAHMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAL4P/reBAlsXKWPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "OUT_DIR = './work/out'                 # directory holding the input image and the trained model\n",
    "image_path = OUT_DIR + '/img.png'      # image to classify\n",
    "model_path = OUT_DIR + '/ssrnet-best'  # checkpoint path passed to fluid.load_dygraph\n",
    "\n",
    "# Load an image and convert it into the network's input format\n",
    "def load_image(image_path):\n",
    "    \"\"\"\n",
    "    Read an image file and turn it into a normalized network input.\n",
    "\n",
    "    Args:\n",
    "        image_path - path of the input image file\n",
    "    Returns:\n",
    "        image - float32 array of shape (1, 3, 32, 32): a single-image batch in\n",
    "                CHW layout, normalized with per-channel mean / std\n",
    "    \"\"\"\n",
    "    # Read the image. The context manager closes the file handle, and\n",
    "    # convert('RGB') drops an alpha channel (common in PNGs) and expands\n",
    "    # grayscale/palette images, so the array below always has 3 channels.\n",
    "    with Image.open(image_path) as img:\n",
    "        img = img.convert('RGB')\n",
    "\n",
    "    # Resize to 32x32. ANTIALIAS was only an alias of LANCZOS and was removed\n",
    "    # in Pillow 10; LANCZOS works on both old and new Pillow versions.\n",
    "    img = img.resize((32, 32), Image.LANCZOS)\n",
    "    image = np.array(img, dtype=np.float32) # HWC, 8-bit values in [0, 255]\n",
    "\n",
    "    # Per-channel normalization statistics (labelled as CIFAR statistics).\n",
    "    # NOTE(review): these numbers match the commonly cited CIFAR-10 statistics;\n",
    "    # they must equal the normalization used during training - confirm.\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, -1))\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, -1))\n",
    "\n",
    "    image = (image/255.0 - mean) / stdv # scale to [0, 1], then standardize per channel\n",
    "    # HWC -> CHW; cast back to float32 because the float64 statistics promoted it\n",
    "    image = image.transpose((2, 0, 1)).astype(np.float32)\n",
    "\n",
    "    # Add a leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32)\n",
    "    image = np.expand_dims(image, axis=0)\n",
    "\n",
    "    return image\n",
    "\n",
    "# Predict the class of the image\n",
    "with fluid.dygraph.guard():\n",
    "    # Read the image and wrap it as a dygraph variable\n",
    "    image = load_image(image_path)\n",
    "    image = fluid.dygraph.to_variable(image)\n",
    "    \n",
    "    # Build the network, restore the trained parameters, switch to eval mode\n",
    "    model = SSRNet()\n",
    "    model_dict, _ = fluid.load_dygraph(model_path) # (parameter dict, optimizer dict); only parameters are used\n",
    "    model.set_dict(model_dict)\n",
    "    model.eval()\n",
    "    \n",
    "    # Forward pass, timed with perf_counter (monotonic, high resolution)\n",
    "    infer_start = time.perf_counter()\n",
    "    infer = model(image)\n",
    "    infer_time = time.perf_counter() - infer_start\n",
    "    \n",
    "    # Class names. Label indices are assumed to follow the alphabetical order\n",
    "    # of the class names (true for the CIFAR-100 fine labels), hence the sort.\n",
    "    # NOTE(review): these are display variants of the official names (e.g.\n",
    "    # 'apples' vs 'apple', 'maple' vs 'maple_tree'); their sorted order happens\n",
    "    # to match the official one - re-check it whenever an entry is edited.\n",
    "    vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "             'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "             'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "             'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "             'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "             'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "             'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "             'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "             'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "             'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "             'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "             'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "             'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "             'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "             'baby', 'boy', 'girl', 'man', 'woman',\n",
    "             'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "             'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "             'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "             'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "             'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # label names\n",
    "    vlist.sort() # ascending alphabetical order == assumed label index order\n",
    "    print('infer time: {:f}s, infer value: {}'.format(infer_time, vlist[np.argmax(infer.numpy())]) )\n",
    "\n",
    "# Show the original (un-normalized) image; read the pixels inside 'with'\n",
    "# so the file handle is closed afterwards.\n",
    "with Image.open(image_path) as preview:\n",
    "    preview_pixels = np.asarray(preview)\n",
    "plt.figure(figsize=(3, 3))     # display size\n",
    "plt.imshow(preview_pixels)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
